WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
22 changed files with 280 additions and 41 deletions
Showing only changes of commit d6b7eb9b30 - Show all commits

View File

@@ -22,5 +22,8 @@ def add_api(collection: _ServiceCollection):
dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
collection.add_singleton(PolicyRegistry)
_ServiceCollection.with_module(add_api, __name__)

View File

@@ -1,5 +1,5 @@
import os
from typing import Mapping, Any, Callable
from typing import Mapping, Any, Callable, Self, Union
import uvicorn
from starlette.applications import Starlette
@@ -11,20 +11,24 @@ from starlette.routing import Route
from starlette.types import ExceptionHandler
from cpl import api, auth
from cpl.api.api_logger import APILogger
from cpl.api.api_settings import ApiSettings
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.error import APIError
from cpl.api.logger import APILogger
from cpl.api.middleware.authentication import AuthenticationMiddleware
from cpl.api.middleware.authorization import AuthorizationMiddleware
from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware
from cpl.api.model.policy import Policy
from cpl.api.router import Router
from cpl.api.typing import HTTPMethods, PartialMiddleware
from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
@@ -32,6 +36,7 @@ class WebApp(ApplicationABC):
self._app: Starlette | None = None
self._api_settings = Configuration.get(ApiSettings)
self._policy_registry = services.get_service(PolicyRegistry)
self._routes: list[Route] = []
self._middleware: list[Middleware] = [
@@ -66,11 +71,12 @@ class WebApp(ApplicationABC):
_logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_database(self):
def with_database(self) -> Self:
self.with_migrations()
self.with_seeders()
return self
def with_app(self, app: Starlette):
def with_app(self, app: Starlette) -> Self:
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
self._app = app
@@ -80,7 +86,7 @@ class WebApp(ApplicationABC):
if self._app is not None:
raise ValueError("App is already set, cannot add routes or middleware")
def with_routes_directory(self, directory: str) -> "WebApp":
def with_routes_directory(self, directory: str) -> Self:
self._check_for_app()
assert directory is not None, "directory must not be None"
@@ -94,14 +100,14 @@ class WebApp(ApplicationABC):
return self
def with_routes(self, routes: list[Route]) -> "WebApp":
def with_routes(self, routes: list[Route]) -> Self:
self._check_for_app()
assert self._routes is not None, "routes must not be None"
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
self._routes.extend(routes)
return self
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> "WebApp":
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self:
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
@@ -117,7 +123,7 @@ class WebApp(ApplicationABC):
self._routes.append(Route(path, fn, methods=[method], **kwargs))
return self
def with_middleware(self, middleware: PartialMiddleware) -> "WebApp":
def with_middleware(self, middleware: PartialMiddleware) -> Self:
self._check_for_app()
if isinstance(middleware, Middleware):
@@ -129,15 +135,49 @@ class WebApp(ApplicationABC):
return self
def with_authentication(self):
def with_authentication(self) -> Self:
self.with_middleware(AuthenticationMiddleware)
return self
def with_authorization(self):
pass
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
if policies:
_policies = []
if not isinstance(policies, list):
policies = list(policies)
for i, policy in enumerate(policies):
if isinstance(policy, dict):
for name, resolver in policy.items():
if not isinstance(name, str):
_logger.warning(f"Skipping policy at index {i}, name must be a string")
continue
if not callable(resolver):
_logger.warning(f"Skipping policy {name}, resolver must be callable")
continue
_policies.append(Policy(name, resolver))
continue
_policies.append(policy)
self._policy_registry.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware)
return self
def _validate_policies(self):
for rule in Router.get_authorization_rules():
for policy_name in rule["policies"]:
policy = self._policy_registry.get(policy_name)
if not policy:
_logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self):
_logger.debug(f"Preparing API")
self._validate_policies()
if self._app is None:
routes = [
Route(
@@ -166,13 +206,6 @@ class WebApp(ApplicationABC):
app = self._app
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
# uvicorn.run(
# app,
# host=self._api_settings.host,
# port=self._api_settings.port,
# log_config=None,
# loop="asyncio"
# )
config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"

View File

@@ -7,14 +7,23 @@ from starlette.types import Scope, Receive, Send
class APIError(HTTPException):
status_code = 500
@classmethod
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send):
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
def __init__(self, message: str = ""):
super().__init__(self.status_code, message)
self._message = message
@property
def error_message(self) -> str:
if self._message:
return f"{type(self).__name__}: {self._message}"
return f"{type(self).__name__}"
async def asgi_response(self, scope: Scope, receive: Receive, send: Send):
r = JSONResponse({"error": self.error_message}, status_code=self.status_code)
return await r(scope, receive, send)
@classmethod
def response(cls):
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
def response(self):
return JSONResponse({"error": self.error_message}, status_code=self.status_code)
class Unauthorized(APIError):

View File

@@ -2,12 +2,13 @@ from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
@@ -53,6 +54,9 @@ class AuthenticationMiddleware(ASGIMiddleware):
_logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user
set_user(user)
return await self._call_next(scope, receive, send)
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:

View File

@@ -0,0 +1,64 @@
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized, Forbidden
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger(__name__)
class AuthorizationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._policies = policies
self._user_dao = user_dao
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
user = get_user()
if not user:
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
roles = await user.roles
request.state.roles = roles
role_names = [r.name for r in roles]
perms = await user.permissions
request.state.permissions = perms
perm_names = [p.name for p in perms]
for rule in Router.get_authorization_rules():
match = rule["match"]
if rule["roles"]:
if match == ValidationMatch.all and not all(r in role_names for r in rule["roles"]):
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(r in role_names for r in rule["roles"]):
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
_logger.warning(f"Authorization policy '{policy_name}' not found")
continue
if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)

View File

@@ -4,7 +4,7 @@ from starlette.requests import Request
from starlette.types import Receive, Scope, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)

View File

@@ -5,10 +5,9 @@ from uuid import uuid4
from starlette.requests import Request
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
@@ -50,5 +49,5 @@ class RequestMiddleware(ASGIMiddleware):
_request_context.reset(self._ctx_token)
def get_request() -> Optional[Union[TRequest, WebSocket]]:
def get_request() -> Optional[TRequest]:
return _request_context.get()

View File

View File

@@ -0,0 +1,34 @@
from asyncio import iscoroutinefunction
from typing import Optional, Any, Coroutine, Awaitable
from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user
class Policy:
def __init__(
self,
name: str,
resolver: PolicyResolver = None,
):
self._name = name
self._resolver: Optional[PolicyResolver] = resolver
@property
def name(self) -> str:
return self._name
@property
def resolvers(self) -> PolicyResolver:
return self._resolver
async def resolve(self, *args, **kwargs) -> bool:
if not self._resolver:
return True
if callable(self._resolver):
if iscoroutinefunction(self._resolver):
return await self._resolver(get_user())
return self._resolver(get_user())
return False

View File

@@ -0,0 +1,6 @@
from enum import Enum
class ValidationMatch(Enum):
any = "any"
all = "all"

View File

View File

@@ -0,0 +1,23 @@
from typing import Optional
from cpl.api.model.policy import Policy
class PolicyRegistry:
def __init__(self):
self._policies: dict[str, Policy] = {}
def extend_policies(self, policies: list[Policy]):
for policy in policies:
self.add_policy(policy)
def add_policy(self, policy: Policy):
assert isinstance(policy, Policy), "policy must be an instance of Policy"
if policy.name in self._policies:
raise ValueError(f"Policy {policy.name} is already registered")
self._policies[policy.name] = policy
def get(self, name: str) -> Optional[Policy]:
return self._policies.get(name)

View File

@@ -1,9 +1,14 @@
from enum import Enum
from starlette.routing import Route
from cpl.api.model.validation_match import ValidationMatch
class Router:
_registered_routes: list[Route] = []
_auth_required: list[str] = []
_authorization_rules: list[dict] = []
@classmethod
def get_routes(cls) -> list[Route]:
@@ -13,6 +18,10 @@ class Router:
def get_auth_required_routes(cls) -> list[str]:
return cls._auth_required
@classmethod
def get_authorization_rules(cls) -> list[dict]:
return cls._authorization_rules
@classmethod
def authenticate(cls):
"""
@@ -32,6 +41,50 @@ class Router:
return inner
@classmethod
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
"""
Decorator to mark a route as requiring authorization.
Usage:
@Route.authorize()
@Route.get("/example")
async def example_endpoint(request: TRequest):
...
"""
assert roles is None or isinstance(roles, list), "roles must be a list of strings"
assert permissions is None or isinstance(permissions, list), "permissions must be a list of strings"
assert policies is None or isinstance(policies, list), "policies must be a list of strings"
assert match is None or isinstance(match, ValidationMatch), "match must be an instance of ValidationMatch"
if roles is not None:
for role in roles:
if isinstance(role, Enum):
roles[roles.index(role)] = role.value
if permissions is not None:
for perm in permissions:
if isinstance(perm, Enum):
permissions[permissions.index(perm)] = perm.value
def inner(fn):
route_path = getattr(fn, "_route_path", None)
if not route_path:
return fn
if route_path in cls._authorization_rules:
raise ValueError(f"Route {route_path} is already registered for authorization")
cls._authorization_rules.append({
"roles": roles or [],
"permissions": permissions or [],
"policies": policies or [],
"match": match or ValidationMatch.all,
})
return fn
return inner
@classmethod
def route(cls, path=None, **kwargs):
def inner(fn):

View File

@@ -1,13 +1,19 @@
from typing import Union, Literal, Callable
from typing import Union, Literal, Callable, Type, Awaitable
from urllib.request import Request
from starlette.middleware import Middleware
from starlette.types import ASGIApp
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.auth.schema import AuthUser
TRequest = Union[Request, WebSocket]
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
PartialMiddleware = Union[
ASGIMiddleware,
Type[ASGIMiddleware],
Middleware,
Callable[[ASGIApp], ASGIApp],
]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

@@ -43,9 +43,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map(
f"""
SELECT COUNT(*)
FROM permission.role_users ru
JOIN permission.role_permissions rp ON ru.roleId = rp.roleId
SELECT COUNT(*) as count
FROM {TableManager.get("role_users")} ru
JOIN {TableManager.get("role_permissions")} rp ON ru.roleId = rp.roleId
WHERE ru.userId = {user_id}
AND rp.permissionId = {p.id}
AND ru.deleted = FALSE
@@ -61,9 +61,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
result = await self._db.select_map(
f"""
SELECT p.*
FROM permission.permissions p
JOIN permission.role_permissions rp ON p.id = rp.permissionId
JOIN permission.role_users ru ON rp.roleId = ru.roleId
FROM {TableManager.get("permissions")} p
JOIN {TableManager.get("role_permissions")} rp ON p.id = rp.permissionId
JOIN {TableManager.get("role_users")} ru ON rp.roleId = ru.roleId
WHERE ru.userId = {user_id}
AND rp.deleted = FALSE
AND ru.deleted = FALSE;

View File

@@ -487,7 +487,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
builder.with_temp_table(self._external_fields[temp])
if for_count:
builder.with_attribute("COUNT(*)", ignore_table_name=True)
builder.with_attribute("COUNT(*) as count", ignore_table_name=True)
else:
builder.with_attribute("*")

View File

@@ -33,7 +33,7 @@ class TableManager:
},
"role_users": {
ServerTypes.POSTGRES: "permission.role_users",
ServerTypes.MYSQL: "permission_role_users",
ServerTypes.MYSQL: "permission_role_auth_users",
},
}

View File

@@ -1,7 +1,7 @@
from starlette.responses import JSONResponse
from cpl import api
from cpl.api.web_app import WebApp
from cpl.api.application.web_app import WebApp
from cpl.application import ApplicationBuilder
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
@@ -24,6 +24,8 @@ def main():
app.with_database()
app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
app.with_routes_directory("routes")

View File

@@ -3,11 +3,14 @@ from urllib.request import Request
from starlette.responses import JSONResponse
from cpl.api.router import Router
from cpl.auth.permission.permissions import Permissions
from cpl.core.log import Logger
from service import PingService
@Router.authenticate()
@Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: Logger):
logger.info(f"Ping: {ping}")