Authorization via with_route
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 18s
Build on push / core (push) Successful in 18s
Build on push / dependency (push) Successful in 17s
Build on push / application (push) Successful in 15s
Build on push / mail (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 19s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s

This commit is contained in:
2025-09-22 22:03:42 +02:00
parent d6b7eb9b30
commit 86ad953ff1
11 changed files with 249 additions and 73 deletions

View File

@@ -23,7 +23,10 @@ def add_api(collection: _ServiceCollection):
dependency_error("cpl-auth", e) dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_singleton(PolicyRegistry) collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)
_ServiceCollection.with_module(add_api, __name__) _ServiceCollection.with_module(add_api, __name__)

View File

@@ -1,4 +1,5 @@
import os import os
from enum import Enum
from typing import Mapping, Any, Callable, Self, Union from typing import Mapping, Any, Callable, Self, Union
import uvicorn import uvicorn
@@ -7,18 +8,20 @@ from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.types import ExceptionHandler from starlette.types import ExceptionHandler
from cpl import api, auth from cpl import api, auth
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.error import APIError from cpl.api.error import APIError
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.middleware.authentication import AuthenticationMiddleware from cpl.api.middleware.authentication import AuthenticationMiddleware
from cpl.api.middleware.authorization import AuthorizationMiddleware from cpl.api.middleware.authorization import AuthorizationMiddleware
from cpl.api.middleware.logging import LoggingMiddleware from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware from cpl.api.middleware.request import RequestMiddleware
from cpl.api.model.api_route import ApiRoute
from cpl.api.model.policy import Policy from cpl.api.model.policy import Policy
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.api.settings import ApiSettings from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
@@ -30,15 +33,16 @@ _logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy] PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC): class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProviderABC):
super().__init__(services, [auth, api]) super().__init__(services, [auth, api])
self._app: Starlette | None = None self._app: Starlette | None = None
self._api_settings = Configuration.get(ApiSettings) self._api_settings = Configuration.get(ApiSettings)
self._policy_registry = services.get_service(PolicyRegistry) self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry)
self._routes: list[Route] = []
self._middleware: list[Middleware] = [ self._middleware: list[Middleware] = [
Middleware(RequestMiddleware), Middleware(RequestMiddleware),
Middleware(LoggingMiddleware), Middleware(LoggingMiddleware),
@@ -100,27 +104,64 @@ class WebApp(ApplicationABC):
return self return self
def with_routes(self, routes: list[Route]) -> Self: def with_routes(
self,
routes: list[ApiRoute],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app() self._check_for_app()
assert self._routes is not None, "routes must not be None" assert self._routes is not None, "routes must not be None"
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route" assert all(isinstance(route, ApiRoute) for route in routes), "all routes must be of type ApiRoute"
self._routes.extend(routes) for route in routes:
self.with_route(
route.path,
route.fn,
method,
authentication,
roles,
permissions,
policies,
match,
)
return self return self
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self: def with_route(
self,
path: str,
fn: Callable[[Request], Any],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app() self._check_for_app()
assert path is not None, "path must not be None" assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None" assert fn is not None, "fn must not be None"
assert method in [ assert method in [
"GET", "GET",
"HEAD",
"POST", "POST",
"PUT", "PUT",
"DELETE",
"PATCH", "PATCH",
"DELETE",
"OPTIONS", "OPTIONS",
"HEAD",
], "method must be a valid HTTP method" ], "method must be a valid HTTP method"
self._routes.append(Route(path, fn, methods=[method], **kwargs))
Router.route(path, method, registry=self._routes)(fn)
if authentication:
Router.authenticate()(fn)
if roles or permissions or policies:
Router.authorize(roles, permissions, policies, match)(fn)
return self return self
def with_middleware(self, middleware: PartialMiddleware) -> Self: def with_middleware(self, middleware: PartialMiddleware) -> Self:
@@ -162,7 +203,7 @@ class WebApp(ApplicationABC):
_policies.append(policy) _policies.append(policy)
self._policy_registry.extend_policies(_policies) self._policies.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware) self.with_middleware(AuthorizationMiddleware)
return self return self
@@ -170,7 +211,7 @@ class WebApp(ApplicationABC):
def _validate_policies(self): def _validate_policies(self):
for rule in Router.get_authorization_rules(): for rule in Router.get_authorization_rules():
for policy_name in rule["policies"]: for policy_name in rule["policies"]:
policy = self._policy_registry.get(policy_name) policy = self._policies.get(policy_name)
if not policy: if not policy:
_logger.fatal(f"Authorization policy '{policy_name}' not found") _logger.fatal(f"Authorization policy '{policy_name}' not found")
@@ -179,15 +220,7 @@ class WebApp(ApplicationABC):
self._validate_policies() self._validate_policies()
if self._app is None: if self._app is None:
routes = [ routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
Route(
path=route.path,
endpoint=self._services.inject(route.endpoint),
methods=route.methods,
name=route.name,
)
for route in self._routes + Router.get_routes()
]
app = Starlette( app = Starlette(
routes=routes, routes=routes,

View File

@@ -25,8 +25,13 @@ class AuthorizationMiddleware(ASGIMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request() request = get_request()
user = get_user() url = request.url.path
if url not in Router.get_authorization_rules_paths():
_logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send)
user = get_user()
if not user: if not user:
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send) return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
@@ -48,9 +53,13 @@ class AuthorizationMiddleware(ASGIMiddleware):
if rule["permissions"]: if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]): if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send) return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]): if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send) return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
for policy_name in rule["policies"]: for policy_name in rule["policies"]:
policy = self._policies.get(policy_name) policy = self._policies.get(policy_name)

View File

@@ -0,0 +1,43 @@
from typing import Callable
from starlette.routing import Route
from cpl.api.typing import HTTPMethods
class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
self._path = path
self._fn = fn
self._method = method
self._kwargs = kwargs
@property
def name(self) -> str:
return self._fn.__name__
@property
def fn(self) -> Callable:
return self._fn
@property
def path(self) -> str:
return self._path
@property
def method(self) -> HTTPMethods:
return self._method
@property
def kwargs(self) -> dict:
return self._kwargs
def to_starlette(self, wrap_endpoint: Callable = None) -> Route:
return Route(
self._path,
self._fn if not wrap_endpoint else wrap_endpoint(self._fn),
methods=[self._method],
**self._kwargs,
)

View File

@@ -1,23 +1,28 @@
from typing import Optional from typing import Optional
from cpl.api.model.policy import Policy from cpl.api.model.policy import Policy
from cpl.core.abc.registry_abc import RegistryABC
class PolicyRegistry: class PolicyRegistry(RegistryABC):
def __init__(self): def __init__(self):
self._policies: dict[str, Policy] = {} RegistryABC.__init__(self)
def extend_policies(self, policies: list[Policy]): def extend(self, items: list[Policy]):
for policy in policies: for policy in items:
self.add_policy(policy) self.add(policy)
def add_policy(self, policy: Policy): def add(self, item: Policy):
assert isinstance(policy, Policy), "policy must be an instance of Policy" assert isinstance(item, Policy), "policy must be an instance of Policy"
if policy.name in self._policies: if item.name in self._items:
raise ValueError(f"Policy {policy.name} is already registered") raise ValueError(f"Policy {item.name} is already registered")
self._policies[policy.name] = policy self._items[item.name] = item
def get(self, name: str) -> Optional[Policy]: def get(self, key: str) -> Optional[Policy]:
return self._policies.get(name) return self._items.get(key)
def all(self) -> list[Policy]:
return list(self._items.values())

View File

@@ -0,0 +1,33 @@
from typing import Optional
from cpl.api.model.policy import Policy
from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC
class RouteRegistry(RegistryABC):
def __init__(self):
RegistryABC.__init__(self)
def extend(self, items: list[ApiRoute]):
for policy in items:
self.add(policy)
def add(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
if item.path in self._items:
raise ValueError(f"ApiRoute {item.path} is already registered")
self._items[item.path] = item
def set(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
self._items[item.path] = item
def get(self, key: str) -> Optional[ApiRoute]:
return self._items.get(key)
def all(self) -> list[ApiRoute]:
return list(self._items.values())

View File

@@ -1,26 +1,25 @@
from enum import Enum from enum import Enum
from starlette.routing import Route
from cpl.api.model.validation_match import ValidationMatch from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods
class Router: class Router:
_registered_routes: list[Route] = []
_auth_required: list[str] = [] _auth_required: list[str] = []
_authorization_rules: list[dict] = [] _authorization_rules: dict[str, dict] = {}
@classmethod
def get_routes(cls) -> list[Route]:
return cls._registered_routes
@classmethod @classmethod
def get_auth_required_routes(cls) -> list[str]: def get_auth_required_routes(cls) -> list[str]:
return cls._auth_required return cls._auth_required
@classmethod
def get_authorization_rules_paths(cls) -> list[str]:
return list(cls._authorization_rules.keys())
@classmethod @classmethod
def get_authorization_rules(cls) -> list[dict]: def get_authorization_rules(cls) -> list[dict]:
return cls._authorization_rules return list(cls._authorization_rules.values())
@classmethod @classmethod
def authenticate(cls): def authenticate(cls):
@@ -42,7 +41,13 @@ class Router:
return inner return inner
@classmethod @classmethod
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None): def authorize(
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
""" """
Decorator to mark a route as requiring authorization. Decorator to mark a route as requiring authorization.
Usage: Usage:
@@ -67,52 +72,64 @@ class Router:
permissions[permissions.index(perm)] = perm.value permissions[permissions.index(perm)] = perm.value
def inner(fn): def inner(fn):
route_path = getattr(fn, "_route_path", None) path = getattr(fn, "_route_path", None)
if not route_path: if not path:
return fn return fn
if route_path in cls._authorization_rules: if path in cls._authorization_rules:
raise ValueError(f"Route {route_path} is already registered for authorization") raise ValueError(f"Route {path} is already registered for authorization")
cls._authorization_rules.append({ cls._authorization_rules[path] = {
"roles": roles or [], "roles": roles or [],
"permissions": permissions or [], "permissions": permissions or [],
"policies": policies or [], "policies": policies or [],
"match": match or ValidationMatch.all, "match": match or ValidationMatch.all,
}) }
return fn return fn
return inner return inner
@classmethod @classmethod
def route(cls, path=None, **kwargs): def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
if not registry:
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else:
routes = registry
def inner(fn): def inner(fn):
cls._registered_routes.append(Route(path, fn, **kwargs)) routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path) setattr(fn, "_route_path", path)
return fn return fn
return inner return inner
@classmethod @classmethod
def get(cls, path=None, **kwargs): def get(cls, path: str, **kwargs):
return cls.route(path, methods=["GET"], **kwargs) return cls.route(path, "GET", **kwargs)
@classmethod @classmethod
def post(cls, path=None, **kwargs): def head(cls, path: str, **kwargs):
return cls.route(path, methods=["POST"], **kwargs) return cls.route(path, "HEAD", **kwargs)
@classmethod @classmethod
def head(cls, path=None, **kwargs): def post(cls, path: str, **kwargs):
return cls.route(path, methods=["HEAD"], **kwargs) return cls.route(path, "POST", **kwargs)
@classmethod @classmethod
def put(cls, path=None, **kwargs): def put(cls, path: str, **kwargs):
return cls.route(path, methods=["PUT"], **kwargs) return cls.route(path, "PUT", **kwargs)
@classmethod @classmethod
def delete(cls, path=None, **kwargs): def patch(cls, path: str, **kwargs):
return cls.route(path, methods=["DELETE"], **kwargs) return cls.route(path, "PATCH", **kwargs)
@classmethod
def delete(cls, path: str, **kwargs):
return cls.route(path, "DELETE", **kwargs)
@classmethod @classmethod
def override(cls): def override(cls):
@@ -125,13 +142,22 @@ class Router:
... ...
""" """
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn): def inner(fn):
route_path = getattr(fn, "_route_path", None) path = getattr(fn, "_route_path", None)
if path is None:
raise ValueError("Cannot override a route that has not been registered yet")
routes = list(filter(lambda x: x.path == route_path, cls._registered_routes)) route = routes.get(path)
for route in routes[:-1]: if route is None:
cls._registered_routes.remove(route) raise ValueError(f"Cannot override a route that does not exist: {path}")
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
setattr(fn, "_route_path", path)
return fn return fn
return inner return inner

View File

@@ -9,7 +9,7 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.auth.schema import AuthUser from cpl.auth.schema import AuthUser
TRequest = Union[Request, WebSocket] TRequest = Union[Request, WebSocket]
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
PartialMiddleware = Union[ PartialMiddleware = Union[
ASGIMiddleware, ASGIMiddleware,
Type[ASGIMiddleware], Type[ASGIMiddleware],

View File

View File

@@ -0,0 +1,23 @@
from abc import abstractmethod, ABC
from typing import Generic
from cpl.core.typing import T
class RegistryABC(ABC, Generic[T]):
@abstractmethod
def __init__(self):
self._items: dict[str, T] = {}
@abstractmethod
def extend(self, items: list[T]) -> None: ...
@abstractmethod
def add(self, item: T) -> None: ...
@abstractmethod
def get(self, key: str) -> T | None: ...
@abstractmethod
def all(self) -> list[T]: ...

View File

@@ -3,6 +3,7 @@ from starlette.responses import JSONResponse
from cpl import api from cpl import api
from cpl.api.application.web_app import WebApp from cpl.api.application.web_app import WebApp
from cpl.application import ApplicationBuilder from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment from cpl.core.environment import Environment
from service import PingService from service import PingService
@@ -26,7 +27,7 @@ def main():
app.with_authentication() app.with_authentication()
app.with_authorization() app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET") app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
app.with_routes_directory("routes") app.with_routes_directory("routes")
app.run() app.run()