diff --git a/src/cpl-api/cpl/api/__init__.py b/src/cpl-api/cpl/api/__init__.py index d3755d7e..3ba6cbd9 100644 --- a/src/cpl-api/cpl/api/__init__.py +++ b/src/cpl-api/cpl/api/__init__.py @@ -22,5 +22,8 @@ def add_api(collection: _ServiceCollection): dependency_error("cpl-auth", e) + from cpl.api.registry.policy import PolicyRegistry + collection.add_singleton(PolicyRegistry) + _ServiceCollection.with_module(add_api, __name__) diff --git a/src/cpl-api/cpl/api/application/__init__.py b/src/cpl-api/cpl/api/application/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-api/cpl/api/web_app.py b/src/cpl-api/cpl/api/application/web_app.py similarity index 71% rename from src/cpl-api/cpl/api/web_app.py rename to src/cpl-api/cpl/api/application/web_app.py index bd5c0b1c..1736678f 100644 --- a/src/cpl-api/cpl/api/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -1,5 +1,5 @@ import os -from typing import Mapping, Any, Callable +from typing import Mapping, Any, Callable, Self, Union import uvicorn from starlette.applications import Starlette @@ -11,20 +11,24 @@ from starlette.routing import Route from starlette.types import ExceptionHandler from cpl import api, auth -from cpl.api.api_logger import APILogger -from cpl.api.api_settings import ApiSettings +from cpl.api.registry.policy import PolicyRegistry from cpl.api.error import APIError +from cpl.api.logger import APILogger from cpl.api.middleware.authentication import AuthenticationMiddleware +from cpl.api.middleware.authorization import AuthorizationMiddleware from cpl.api.middleware.logging import LoggingMiddleware from cpl.api.middleware.request import RequestMiddleware +from cpl.api.model.policy import Policy from cpl.api.router import Router -from cpl.api.typing import HTTPMethods, PartialMiddleware +from cpl.api.settings import ApiSettings +from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.application.abc.application_abc import ApplicationABC from cpl.core.configuration import Configuration from cpl.dependency.service_provider_abc import ServiceProviderABC _logger = APILogger("API") +PolicyInput = Union[dict[str, PolicyResolver], Policy] class WebApp(ApplicationABC): def __init__(self, services: ServiceProviderABC): @@ -32,6 +36,7 @@ class WebApp(ApplicationABC): self._app: Starlette | None = None self._api_settings = Configuration.get(ApiSettings) + self._policy_registry = services.get_service(PolicyRegistry) self._routes: list[Route] = [] self._middleware: list[Middleware] = [ @@ -66,11 +71,12 @@ class WebApp(ApplicationABC): _logger.debug(f"Allowed origins: {origins}") return origins.split(",") - def with_database(self): + def with_database(self) -> Self: self.with_migrations() self.with_seeders() + return self - def with_app(self, app: Starlette): + def with_app(self, app: Starlette) -> Self: assert app is not None, "app must not be None" assert isinstance(app, Starlette), "app must be an instance of Starlette" self._app = app @@ -80,7 +86,7 @@ class WebApp(ApplicationABC): if self._app is not None: raise ValueError("App is already set, cannot add routes or middleware") - def with_routes_directory(self, directory: str) -> "WebApp": + def with_routes_directory(self, directory: str) -> Self: self._check_for_app() assert directory is not None, "directory must not be None" @@ -94,14 +100,14 @@ class WebApp(ApplicationABC): return self - def with_routes(self, routes: list[Route]) -> "WebApp": + def with_routes(self, routes: list[Route]) -> Self: self._check_for_app() assert self._routes is not None, "routes must not be None" assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route" self._routes.extend(routes) return self - def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> "WebApp": + def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self: self._check_for_app() assert path is not None, "path must not be None" assert fn is not None, "fn must not be None" @@ -117,7 +123,7 @@ class WebApp(ApplicationABC): self._routes.append(Route(path, fn, methods=[method], **kwargs)) return self - def with_middleware(self, middleware: PartialMiddleware) -> "WebApp": + def with_middleware(self, middleware: PartialMiddleware) -> Self: self._check_for_app() if isinstance(middleware, Middleware): @@ -129,15 +135,49 @@ class WebApp(ApplicationABC): return self - def with_authentication(self): + def with_authentication(self) -> Self: self.with_middleware(AuthenticationMiddleware) return self - def with_authorization(self): - pass + def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: + if policies: + _policies = [] + + if not isinstance(policies, list): + policies = list(policies) + + for i, policy in enumerate(policies): + if isinstance(policy, dict): + for name, resolver in policy.items(): + if not isinstance(name, str): + _logger.warning(f"Skipping policy at index {i}, name must be a string") + continue + + if not callable(resolver): + _logger.warning(f"Skipping policy {name}, resolver must be callable") + continue + + _policies.append(Policy(name, resolver)) + continue + + _policies.append(policy) + + self._policy_registry.extend_policies(_policies) + + self.with_middleware(AuthorizationMiddleware) + return self + + def _validate_policies(self): + for rule in Router.get_authorization_rules(): + for policy_name in rule["policies"]: + policy = self._policy_registry.get(policy_name) + if not policy: + _logger.fatal(f"Authorization policy '{policy_name}' not found") async def main(self): _logger.debug(f"Preparing API") + self._validate_policies() + if self._app is None: routes = [ Route( @@ -166,13 +206,6 @@ class WebApp(ApplicationABC): app = self._app _logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") - # uvicorn.run( - # app, - # host=self._api_settings.host, - # port=self._api_settings.port, - # log_config=None, - # loop="asyncio" - # ) config = uvicorn.Config( app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" diff --git a/src/cpl-api/cpl/api/error.py b/src/cpl-api/cpl/api/error.py index b58df339..50329e98 100644 --- a/src/cpl-api/cpl/api/error.py +++ b/src/cpl-api/cpl/api/error.py @@ -7,14 +7,23 @@ from starlette.types import Scope, Receive, Send class APIError(HTTPException): status_code = 500 - @classmethod - async def asgi_response(cls, scope: Scope, receive: Receive, send: Send): - r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code) + def __init__(self, message: str = ""): + super().__init__(self.status_code, message) + self._message = message + + @property + def error_message(self) -> str: + if self._message: + return f"{type(self).__name__}: {self._message}" + + return f"{type(self).__name__}" + + async def asgi_response(self, scope: Scope, receive: Receive, send: Send): + r = JSONResponse({"error": self.error_message}, status_code=self.status_code) return await r(scope, receive, send) - @classmethod - def response(cls): - return JSONResponse({"error": cls.__name__}, status_code=cls.status_code) + def response(self): + return JSONResponse({"error": self.error_message}, status_code=self.status_code) class Unauthorized(APIError): diff --git a/src/cpl-api/cpl/api/api_logger.py b/src/cpl-api/cpl/api/logger.py similarity index 100% rename from src/cpl-api/cpl/api/api_logger.py rename to src/cpl-api/cpl/api/logger.py diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index 4b4b5cc6..cd047706 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -2,12 +2,13 @@ from keycloak import KeycloakAuthenticationError from starlette.types import Scope, Receive, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware -from cpl.api.api_logger import APILogger +from cpl.api.logger import APILogger from cpl.api.error import Unauthorized from cpl.api.middleware.request import get_request from cpl.api.router import Router from cpl.auth.keycloak import KeycloakClient from cpl.auth.schema import AuthUserDao, AuthUser +from cpl.core.ctx import set_user from cpl.dependency import ServiceProviderABC _logger = APILogger(__name__) @@ -53,6 +54,9 @@ class AuthenticationMiddleware(ASGIMiddleware): _logger.debug(f"Unauthorized access to {url}, user is deleted") return await Unauthorized("User is deleted").asgi_response(scope, receive, send) + request.state.user = user + set_user(user) + return await self._call_next(scope, receive, send) async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser: diff --git a/src/cpl-api/cpl/api/middleware/authorization.py b/src/cpl-api/cpl/api/middleware/authorization.py new file mode 100644 index 00000000..6d760b3a --- /dev/null +++ b/src/cpl-api/cpl/api/middleware/authorization.py @@ -0,0 +1,64 @@ +from starlette.types import Scope, Receive, Send + +from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware +from cpl.api.error import Unauthorized, Forbidden +from cpl.api.logger import APILogger +from cpl.api.middleware.request import get_request +from cpl.api.model.validation_match import ValidationMatch +from cpl.api.registry.policy import PolicyRegistry +from cpl.api.router import Router +from cpl.auth.schema._administration.auth_user_dao import AuthUserDao +from cpl.core.ctx.user_context import get_user +from cpl.dependency.service_provider_abc import ServiceProviderABC + +_logger = APILogger(__name__) + + +class AuthorizationMiddleware(ASGIMiddleware): + + @ServiceProviderABC.inject + def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao): + ASGIMiddleware.__init__(self, app) + + self._policies = policies + self._user_dao = user_dao + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + request = get_request() + user = get_user() + + if not user: + return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send) + + roles = await user.roles + request.state.roles = roles + role_names = [r.name for r in roles] + + perms = await user.permissions + request.state.permissions = perms + perm_names = [p.name for p in perms] + + for rule in Router.get_authorization_rules(): + match = rule["match"] + if rule["roles"]: + if match == ValidationMatch.all and not all(r in role_names for r in rule["roles"]): + return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send) + if match == ValidationMatch.any and not any(r in role_names for r in rule["roles"]): + return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send) + + if rule["permissions"]: + if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]): + return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send) + if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]): + return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send) + + for policy_name in rule["policies"]: + policy = self._policies.get(policy_name) + if not policy: + _logger.warning(f"Authorization policy '{policy_name}' not found") + continue + + if not await policy.resolve(user): + return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send) + + return await self._call_next(scope, receive, send) \ No newline at end of file diff --git a/src/cpl-api/cpl/api/middleware/logging.py b/src/cpl-api/cpl/api/middleware/logging.py index 21feb63f..e47cbe77 100644 --- a/src/cpl-api/cpl/api/middleware/logging.py +++ b/src/cpl-api/cpl/api/middleware/logging.py @@ -4,7 +4,7 @@ from starlette.requests import Request from starlette.types import Receive, Scope, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware -from cpl.api.api_logger import APILogger +from cpl.api.logger import APILogger from cpl.api.middleware.request import get_request _logger = APILogger(__name__) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 85051fde..5255b26c 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -5,10 +5,9 @@ from uuid import uuid4 from starlette.requests import Request from starlette.types import Scope, Receive, Send -from starlette.websockets import WebSocket from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware -from cpl.api.api_logger import APILogger +from cpl.api.logger import APILogger from cpl.api.typing import TRequest _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) @@ -50,5 +49,5 @@ class RequestMiddleware(ASGIMiddleware): _request_context.reset(self._ctx_token) -def get_request() -> Optional[Union[TRequest, WebSocket]]: +def get_request() -> Optional[TRequest]: return _request_context.get() diff --git a/src/cpl-api/cpl/api/model/__init__.py b/src/cpl-api/cpl/api/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-api/cpl/api/model/policy.py b/src/cpl-api/cpl/api/model/policy.py new file mode 100644 index 00000000..ea118e0c --- /dev/null +++ b/src/cpl-api/cpl/api/model/policy.py @@ -0,0 +1,34 @@ +from asyncio import iscoroutinefunction +from typing import Optional, Any, Coroutine, Awaitable + +from cpl.api.typing import PolicyResolver +from cpl.core.ctx import get_user + + +class Policy: + def __init__( + self, + name: str, + resolver: PolicyResolver = None, + ): + self._name = name + self._resolver: Optional[PolicyResolver] = resolver + + @property + def name(self) -> str: + return self._name + + @property + def resolvers(self) -> PolicyResolver: + return self._resolver + + async def resolve(self, *args, **kwargs) -> bool: + if not self._resolver: + return True + + if callable(self._resolver): + if iscoroutinefunction(self._resolver): + return await self._resolver(get_user()) + + return self._resolver(get_user()) + return False diff --git a/src/cpl-api/cpl/api/model/validation_match.py b/src/cpl-api/cpl/api/model/validation_match.py new file mode 100644 index 00000000..9121fa95 --- /dev/null +++ b/src/cpl-api/cpl/api/model/validation_match.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class ValidationMatch(Enum): + any = "any" + all = "all" diff --git a/src/cpl-api/cpl/api/registry/__init__.py b/src/cpl-api/cpl/api/registry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-api/cpl/api/registry/policy.py b/src/cpl-api/cpl/api/registry/policy.py new file mode 100644 index 00000000..63fef54a --- /dev/null +++ b/src/cpl-api/cpl/api/registry/policy.py @@ -0,0 +1,23 @@ +from typing import Optional + +from cpl.api.model.policy import Policy + + +class PolicyRegistry: + def __init__(self): + self._policies: dict[str, Policy] = {} + + def extend_policies(self, policies: list[Policy]): + for policy in policies: + self.add_policy(policy) + + def add_policy(self, policy: Policy): + assert isinstance(policy, Policy), "policy must be an instance of Policy" + + if policy.name in self._policies: + raise ValueError(f"Policy {policy.name} is already registered") + + self._policies[policy.name] = policy + + def get(self, name: str) -> Optional[Policy]: + return self._policies.get(name) diff --git a/src/cpl-api/cpl/api/router.py b/src/cpl-api/cpl/api/router.py index e8936b25..7fa2df99 100644 --- a/src/cpl-api/cpl/api/router.py +++ b/src/cpl-api/cpl/api/router.py @@ -1,9 +1,14 @@ +from enum import Enum + from starlette.routing import Route +from cpl.api.model.validation_match import ValidationMatch + class Router: _registered_routes: list[Route] = [] _auth_required: list[str] = [] + _authorization_rules: list[dict] = [] @classmethod def get_routes(cls) -> list[Route]: @@ -13,6 +18,10 @@ class Router: def get_auth_required_routes(cls) -> list[str]: return cls._auth_required + @classmethod + def get_authorization_rules(cls) -> list[dict]: + return cls._authorization_rules + @classmethod def authenticate(cls): """ @@ -32,6 +41,50 @@ class Router: return inner + @classmethod + def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None): + """ + Decorator to mark a route as requiring authorization. + Usage: + @Route.authorize() + @Route.get("/example") + async def example_endpoint(request: TRequest): + ... + """ + assert roles is None or isinstance(roles, list), "roles must be a list of strings" + assert permissions is None or isinstance(permissions, list), "permissions must be a list of strings" + assert policies is None or isinstance(policies, list), "policies must be a list of strings" + assert match is None or isinstance(match, ValidationMatch), "match must be an instance of ValidationMatch" + + if roles is not None: + for role in roles: + if isinstance(role, Enum): + roles[roles.index(role)] = role.value + + if permissions is not None: + for perm in permissions: + if isinstance(perm, Enum): + permissions[permissions.index(perm)] = perm.value + + def inner(fn): + route_path = getattr(fn, "_route_path", None) + if not route_path: + return fn + + if route_path in cls._authorization_rules: + raise ValueError(f"Route {route_path} is already registered for authorization") + + cls._authorization_rules.append({ + "roles": roles or [], + "permissions": permissions or [], + "policies": policies or [], + "match": match or ValidationMatch.all, + }) + + return fn + + return inner + @classmethod def route(cls, path=None, **kwargs): def inner(fn): diff --git a/src/cpl-api/cpl/api/api_settings.py b/src/cpl-api/cpl/api/settings.py similarity index 100% rename from src/cpl-api/cpl/api/api_settings.py rename to src/cpl-api/cpl/api/settings.py diff --git a/src/cpl-api/cpl/api/typing.py b/src/cpl-api/cpl/api/typing.py index ca570e59..b139b8a7 100644 --- a/src/cpl-api/cpl/api/typing.py +++ b/src/cpl-api/cpl/api/typing.py @@ -1,13 +1,19 @@ -from typing import Union, Literal, Callable +from typing import Union, Literal, Callable, Type, Awaitable from urllib.request import Request from starlette.middleware import Middleware from starlette.types import ASGIApp from starlette.websockets import WebSocket +from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware +from cpl.auth.schema import AuthUser + TRequest = Union[Request, WebSocket] HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] PartialMiddleware = Union[ + ASGIMiddleware, + Type[ASGIMiddleware], Middleware, Callable[[ASGIApp], ASGIApp], ] +PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] \ No newline at end of file diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py index df414960..c219f87d 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py @@ -43,9 +43,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value) result = await self._db.select_map( f""" - SELECT COUNT(*) - FROM permission.role_users ru - JOIN permission.role_permissions rp ON ru.roleId = rp.roleId + SELECT COUNT(*) as count + FROM {TableManager.get("role_users")} ru + JOIN {TableManager.get("role_permissions")} rp ON ru.roleId = rp.roleId WHERE ru.userId = {user_id} AND rp.permissionId = {p.id} AND ru.deleted = FALSE @@ -61,9 +61,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): result = await self._db.select_map( f""" SELECT p.* - FROM permission.permissions p - JOIN permission.role_permissions rp ON p.id = rp.permissionId - JOIN permission.role_users ru ON rp.roleId = ru.roleId + FROM {TableManager.get("permissions")} p + JOIN {TableManager.get("role_permissions")} rp ON p.id = rp.permissionId + JOIN {TableManager.get("role_users")} ru ON rp.roleId = ru.roleId WHERE ru.userId = {user_id} AND rp.deleted = FALSE AND ru.deleted = FALSE; diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py index 1ea7b88f..ec011ca9 100644 --- a/src/cpl-database/cpl/database/abc/data_access_object_abc.py +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -487,7 +487,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): builder.with_temp_table(self._external_fields[temp]) if for_count: - builder.with_attribute("COUNT(*)", ignore_table_name=True) + builder.with_attribute("COUNT(*) as count", ignore_table_name=True) else: builder.with_attribute("*") diff --git a/src/cpl-database/cpl/database/table_manager.py b/src/cpl-database/cpl/database/table_manager.py index 1d7ad7a1..9bd1f6b2 100644 --- a/src/cpl-database/cpl/database/table_manager.py +++ b/src/cpl-database/cpl/database/table_manager.py @@ -33,7 +33,7 @@ class TableManager: }, "role_users": { ServerTypes.POSTGRES: "permission.role_users", - ServerTypes.MYSQL: "permission_role_users", + ServerTypes.MYSQL: "permission_role_auth_users", }, } diff --git a/tests/custom/api/src/main.py b/tests/custom/api/src/main.py index 58ad878d..ffae56c8 100644 --- a/tests/custom/api/src/main.py +++ b/tests/custom/api/src/main.py @@ -1,7 +1,7 @@ from starlette.responses import JSONResponse from cpl import api -from cpl.api.web_app import WebApp +from cpl.api.application.web_app import WebApp from cpl.application import ApplicationBuilder from cpl.core.configuration import Configuration from cpl.core.environment import Environment @@ -24,6 +24,8 @@ def main(): app.with_database() app.with_authentication() + app.with_authorization() + app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET") app.with_routes_directory("routes") diff --git a/tests/custom/api/src/routes/ping.py b/tests/custom/api/src/routes/ping.py index 0324a2f6..b77273e2 100644 --- a/tests/custom/api/src/routes/ping.py +++ b/tests/custom/api/src/routes/ping.py @@ -3,11 +3,14 @@ from urllib.request import Request from starlette.responses import JSONResponse from cpl.api.router import Router +from cpl.auth.permission.permissions import Permissions from cpl.core.log import Logger from service import PingService @Router.authenticate() +@Router.authorize(permissions=[Permissions.administrator]) +# @Router.authorize(policies=["test"]) @Router.get(f"/ping") async def ping(r: Request, ping: PingService, logger: Logger): logger.info(f"Ping: {ping}")