diff --git a/src/cpl-api/cpl/api/__init__.py b/src/cpl-api/cpl/api/__init__.py index 3ba6cbd9..b163c104 100644 --- a/src/cpl-api/cpl/api/__init__.py +++ b/src/cpl-api/cpl/api/__init__.py @@ -23,7 +23,10 @@ def add_api(collection: _ServiceCollection): dependency_error("cpl-auth", e) from cpl.api.registry.policy import PolicyRegistry + from cpl.api.registry.route import RouteRegistry + collection.add_singleton(PolicyRegistry) + collection.add_singleton(RouteRegistry) _ServiceCollection.with_module(add_api, __name__) diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index 1736678f..0daccaba 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -1,4 +1,5 @@ import os +from enum import Enum from typing import Mapping, Any, Callable, Self, Union import uvicorn @@ -7,18 +8,20 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.routing import Route from starlette.types import ExceptionHandler from cpl import api, auth -from cpl.api.registry.policy import PolicyRegistry from cpl.api.error import APIError from cpl.api.logger import APILogger from cpl.api.middleware.authentication import AuthenticationMiddleware from cpl.api.middleware.authorization import AuthorizationMiddleware from cpl.api.middleware.logging import LoggingMiddleware from cpl.api.middleware.request import RequestMiddleware +from cpl.api.model.api_route import ApiRoute from cpl.api.model.policy import Policy +from cpl.api.model.validation_match import ValidationMatch +from cpl.api.registry.policy import PolicyRegistry +from cpl.api.registry.route import RouteRegistry from cpl.api.router import Router from cpl.api.settings import ApiSettings from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver @@ -30,15 +33,16 @@ _logger = APILogger("API") PolicyInput = Union[dict[str, PolicyResolver], Policy] + class WebApp(ApplicationABC): def __init__(self, services: ServiceProviderABC): super().__init__(services, [auth, api]) self._app: Starlette | None = None self._api_settings = Configuration.get(ApiSettings) - self._policy_registry = services.get_service(PolicyRegistry) + self._policies = services.get_service(PolicyRegistry) + self._routes = services.get_service(RouteRegistry) - self._routes: list[Route] = [] self._middleware: list[Middleware] = [ Middleware(RequestMiddleware), Middleware(LoggingMiddleware), @@ -100,27 +104,64 @@ class WebApp(ApplicationABC): return self - def with_routes(self, routes: list[Route]) -> Self: + def with_routes( + self, + routes: list[ApiRoute], + method: HTTPMethods, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: self._check_for_app() assert self._routes is not None, "routes must not be None" - assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route" - self._routes.extend(routes) + assert all(isinstance(route, ApiRoute) for route in routes), "all routes must be of type ApiRoute" + for route in routes: + self.with_route( + route.path, + route.fn, + method, + authentication, + roles, + permissions, + policies, + match, + ) return self - def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self: + def with_route( + self, + path: str, + fn: Callable[[Request], Any], + method: HTTPMethods, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: self._check_for_app() assert path is not None, "path must not be None" assert fn is not None, "fn must not be None" assert method in [ "GET", + "HEAD", "POST", "PUT", - "DELETE", "PATCH", + "DELETE", "OPTIONS", - "HEAD", ], "method must be a valid HTTP method" - self._routes.append(Route(path, fn, methods=[method], **kwargs)) + + Router.route(path, method, registry=self._routes)(fn) + + if authentication: + Router.authenticate()(fn) + + if roles or permissions or policies: + Router.authorize(roles, permissions, policies, match)(fn) + return self def with_middleware(self, middleware: PartialMiddleware) -> Self: @@ -162,7 +203,7 @@ class WebApp(ApplicationABC): _policies.append(policy) - self._policy_registry.extend_policies(_policies) + self._policies.extend_policies(_policies) self.with_middleware(AuthorizationMiddleware) return self @@ -170,7 +211,7 @@ class WebApp(ApplicationABC): def _validate_policies(self): for rule in Router.get_authorization_rules(): for policy_name in rule["policies"]: - policy = self._policy_registry.get(policy_name) + policy = self._policies.get(policy_name) if not policy: _logger.fatal(f"Authorization policy '{policy_name}' not found") @@ -179,15 +220,7 @@ class WebApp(ApplicationABC): self._validate_policies() if self._app is None: - routes = [ - Route( - path=route.path, - endpoint=self._services.inject(route.endpoint), - methods=route.methods, - name=route.name, - ) - for route in self._routes + Router.get_routes() - ] + routes = [route.to_starlette(self._services.inject) for route in self._routes.all()] app = Starlette( routes=routes, diff --git a/src/cpl-api/cpl/api/middleware/authorization.py b/src/cpl-api/cpl/api/middleware/authorization.py index 6d760b3a..021017ba 100644 --- a/src/cpl-api/cpl/api/middleware/authorization.py +++ b/src/cpl-api/cpl/api/middleware/authorization.py @@ -25,8 +25,13 @@ class AuthorizationMiddleware(ASGIMiddleware): async def __call__(self, scope: Scope, receive: Receive, send: Send): request = get_request() - user = get_user() + url = request.url.path + if url not in Router.get_authorization_rules_paths(): + _logger.trace(f"No authorization required for {url}") + return await self._app(scope, receive, send) + + user = get_user() if not user: return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send) @@ -48,9 +53,13 @@ class AuthorizationMiddleware(ASGIMiddleware): if rule["permissions"]: if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]): - return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send) + return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response( + scope, receive, send + ) if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]): - return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send) + return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response( + scope, receive, send + ) for policy_name in rule["policies"]: policy = self._policies.get(policy_name) @@ -61,4 +70,4 @@ class AuthorizationMiddleware(ASGIMiddleware): if not await policy.resolve(user): return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send) - return await self._call_next(scope, receive, send) \ No newline at end of file + return await self._call_next(scope, receive, send) diff --git a/src/cpl-api/cpl/api/model/api_route.py b/src/cpl-api/cpl/api/model/api_route.py new file mode 100644 index 00000000..64f94d34 --- /dev/null +++ b/src/cpl-api/cpl/api/model/api_route.py @@ -0,0 +1,43 @@ +from typing import Callable + +from starlette.routing import Route + +from cpl.api.typing import HTTPMethods + + +class ApiRoute: + + def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs): + self._path = path + self._fn = fn + self._method = method + + self._kwargs = kwargs + + @property + def name(self) -> str: + return self._fn.__name__ + + @property + def fn(self) -> Callable: + return self._fn + + @property + def path(self) -> str: + return self._path + + @property + def method(self) -> HTTPMethods: + return self._method + + @property + def kwargs(self) -> dict: + return self._kwargs + + def to_starlette(self, wrap_endpoint: Callable = None) -> Route: + return Route( + self._path, + self._fn if not wrap_endpoint else wrap_endpoint(self._fn), + methods=[self._method], + **self._kwargs, + ) diff --git a/src/cpl-api/cpl/api/registry/policy.py b/src/cpl-api/cpl/api/registry/policy.py index 63fef54a..f59d9bb2 100644 --- a/src/cpl-api/cpl/api/registry/policy.py +++ b/src/cpl-api/cpl/api/registry/policy.py @@ -1,23 +1,28 @@ from typing import Optional from cpl.api.model.policy import Policy +from cpl.core.abc.registry_abc import RegistryABC -class PolicyRegistry: +class PolicyRegistry(RegistryABC): + def __init__(self): - self._policies: dict[str, Policy] = {} + RegistryABC.__init__(self) - def extend_policies(self, policies: list[Policy]): - for policy in policies: - self.add_policy(policy) + def extend(self, items: list[Policy]): + for policy in items: + self.add(policy) - def add_policy(self, policy: Policy): - assert isinstance(policy, Policy), "policy must be an instance of Policy" + def add(self, item: Policy): + assert isinstance(item, Policy), "policy must be an instance of Policy" - if policy.name in self._policies: - raise ValueError(f"Policy {policy.name} is already registered") + if item.name in self._items: + raise ValueError(f"Policy {item.name} is already registered") - self._policies[policy.name] = policy + self._items[item.name] = item - def get(self, name: str) -> Optional[Policy]: - return self._policies.get(name) + def get(self, key: str) -> Optional[Policy]: + return self._items.get(key) + + def all(self) -> list[Policy]: + return list(self._items.values()) diff --git a/src/cpl-api/cpl/api/registry/route.py b/src/cpl-api/cpl/api/registry/route.py new file mode 100644 index 00000000..6e9b167d --- /dev/null +++ b/src/cpl-api/cpl/api/registry/route.py @@ -0,0 +1,33 @@ +from typing import Optional + +from cpl.api.model.policy import Policy +from cpl.api.model.api_route import ApiRoute +from cpl.core.abc.registry_abc import RegistryABC + + +class RouteRegistry(RegistryABC): + + def __init__(self): + RegistryABC.__init__(self) + + def extend(self, items: list[ApiRoute]): + for policy in items: + self.add(policy) + + def add(self, item: ApiRoute): + assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" + + if item.path in self._items: + raise ValueError(f"ApiRoute {item.path} is already registered") + + self._items[item.path] = item + + def set(self, item: ApiRoute): + assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" + self._items[item.path] = item + + def get(self, key: str) -> Optional[ApiRoute]: + return self._items.get(key) + + def all(self) -> list[ApiRoute]: + return list(self._items.values()) diff --git a/src/cpl-api/cpl/api/router.py b/src/cpl-api/cpl/api/router.py index 7fa2df99..d89b2458 100644 --- a/src/cpl-api/cpl/api/router.py +++ b/src/cpl-api/cpl/api/router.py @@ -1,26 +1,25 @@ from enum import Enum -from starlette.routing import Route - from cpl.api.model.validation_match import ValidationMatch +from cpl.api.registry.route import RouteRegistry +from cpl.api.typing import HTTPMethods class Router: - _registered_routes: list[Route] = [] _auth_required: list[str] = [] - _authorization_rules: list[dict] = [] - - @classmethod - def get_routes(cls) -> list[Route]: - return cls._registered_routes + _authorization_rules: dict[str, dict] = {} @classmethod def get_auth_required_routes(cls) -> list[str]: return cls._auth_required + @classmethod + def get_authorization_rules_paths(cls) -> list[str]: + return list(cls._authorization_rules.keys()) + @classmethod def get_authorization_rules(cls) -> list[dict]: - return cls._authorization_rules + return list(cls._authorization_rules.values()) @classmethod def authenticate(cls): @@ -42,7 +41,13 @@ class Router: return inner @classmethod - def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None): + def authorize( + cls, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ): """ Decorator to mark a route as requiring authorization. Usage: @@ -67,52 +72,64 @@ class Router: permissions[permissions.index(perm)] = perm.value def inner(fn): - route_path = getattr(fn, "_route_path", None) - if not route_path: + path = getattr(fn, "_route_path", None) + if not path: return fn - if route_path in cls._authorization_rules: - raise ValueError(f"Route {route_path} is already registered for authorization") + if path in cls._authorization_rules: + raise ValueError(f"Route {path} is already registered for authorization") - cls._authorization_rules.append({ + cls._authorization_rules[path] = { "roles": roles or [], "permissions": permissions or [], "policies": policies or [], "match": match or ValidationMatch.all, - }) + } return fn return inner @classmethod - def route(cls, path=None, **kwargs): + def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): + if not registry: + from cpl.api.model.api_route import ApiRoute + from cpl.dependency.service_provider_abc import ServiceProviderABC + + routes = ServiceProviderABC.get_global_service(RouteRegistry) + else: + routes = registry + def inner(fn): - cls._registered_routes.append(Route(path, fn, **kwargs)) + routes.add(ApiRoute(path, fn, method, **kwargs)) setattr(fn, "_route_path", path) return fn return inner @classmethod - def get(cls, path=None, **kwargs): - return cls.route(path, methods=["GET"], **kwargs) + def get(cls, path: str, **kwargs): + return cls.route(path, "GET", **kwargs) @classmethod - def post(cls, path=None, **kwargs): - return cls.route(path, methods=["POST"], **kwargs) + def head(cls, path: str, **kwargs): + return cls.route(path, "HEAD", **kwargs) @classmethod - def head(cls, path=None, **kwargs): - return cls.route(path, methods=["HEAD"], **kwargs) + def post(cls, path: str, **kwargs): + return cls.route(path, "POST", **kwargs) @classmethod - def put(cls, path=None, **kwargs): - return cls.route(path, methods=["PUT"], **kwargs) + def put(cls, path: str, **kwargs): + return cls.route(path, "PUT", **kwargs) @classmethod - def delete(cls, path=None, **kwargs): - return cls.route(path, methods=["DELETE"], **kwargs) + def patch(cls, path: str, **kwargs): + return cls.route(path, "PATCH", **kwargs) + + @classmethod + def delete(cls, path: str, **kwargs): + return cls.route(path, "DELETE", **kwargs) @classmethod def override(cls): @@ -125,13 +142,22 @@ class Router: ... """ + from cpl.api.model.api_route import ApiRoute + from cpl.dependency.service_provider_abc import ServiceProviderABC + + routes = ServiceProviderABC.get_global_service(RouteRegistry) + def inner(fn): - route_path = getattr(fn, "_route_path", None) + path = getattr(fn, "_route_path", None) + if path is None: + raise ValueError("Cannot override a route that has not been registered yet") - routes = list(filter(lambda x: x.path == route_path, cls._registered_routes)) - for route in routes[:-1]: - cls._registered_routes.remove(route) + route = routes.get(path) + if route is None: + raise ValueError(f"Cannot override a route that does not exist: {path}") + routes.add(ApiRoute(path, fn, route.method, **route.kwargs)) + setattr(fn, "_route_path", path) return fn return inner diff --git a/src/cpl-api/cpl/api/typing.py b/src/cpl-api/cpl/api/typing.py index b139b8a7..c8319900 100644 --- a/src/cpl-api/cpl/api/typing.py +++ b/src/cpl-api/cpl/api/typing.py @@ -9,11 +9,11 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.auth.schema import AuthUser TRequest = Union[Request, WebSocket] -HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] +HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] PartialMiddleware = Union[ ASGIMiddleware, Type[ASGIMiddleware], Middleware, Callable[[ASGIApp], ASGIApp], ] -PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] \ No newline at end of file +PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] diff --git a/src/cpl-core/cpl/core/abc/__init__.py b/src/cpl-core/cpl/core/abc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-core/cpl/core/abc/registry_abc.py b/src/cpl-core/cpl/core/abc/registry_abc.py new file mode 100644 index 00000000..50837ad8 --- /dev/null +++ b/src/cpl-core/cpl/core/abc/registry_abc.py @@ -0,0 +1,23 @@ +from abc import abstractmethod, ABC +from typing import Generic + +from cpl.core.typing import T + + +class RegistryABC(ABC, Generic[T]): + + @abstractmethod + def __init__(self): + self._items: dict[str, T] = {} + + @abstractmethod + def extend(self, items: list[T]) -> None: ... + + @abstractmethod + def add(self, item: T) -> None: ... + + @abstractmethod + def get(self, key: str) -> T | None: ... + + @abstractmethod + def all(self) -> list[T]: ... diff --git a/tests/custom/api/src/main.py b/tests/custom/api/src/main.py index ffae56c8..286a6399 100644 --- a/tests/custom/api/src/main.py +++ b/tests/custom/api/src/main.py @@ -3,6 +3,7 @@ from starlette.responses import JSONResponse from cpl import api from cpl.api.application.web_app import WebApp from cpl.application import ApplicationBuilder +from cpl.auth.permission.permissions import Permissions from cpl.core.configuration import Configuration from cpl.core.environment import Environment from service import PingService @@ -26,7 +27,7 @@ def main(): app.with_authentication() app.with_authorization() - app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET") + app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator]) app.with_routes_directory("routes") app.run()