Authorization via with_route
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 18s
Build on push / core (push) Successful in 18s
Build on push / dependency (push) Successful in 17s
Build on push / application (push) Successful in 15s
Build on push / mail (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 19s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s

This commit is contained in:
2025-09-22 22:03:42 +02:00
parent d6b7eb9b30
commit 86ad953ff1
11 changed files with 249 additions and 73 deletions

View File

@@ -23,7 +23,10 @@ def add_api(collection: _ServiceCollection):
dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)
_ServiceCollection.with_module(add_api, __name__)

View File

@@ -1,4 +1,5 @@
import os
from enum import Enum
from typing import Mapping, Any, Callable, Self, Union
import uvicorn
@@ -7,18 +8,20 @@ from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.types import ExceptionHandler
from cpl import api, auth
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.error import APIError
from cpl.api.logger import APILogger
from cpl.api.middleware.authentication import AuthenticationMiddleware
from cpl.api.middleware.authorization import AuthorizationMiddleware
from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware
from cpl.api.model.api_route import ApiRoute
from cpl.api.model.policy import Policy
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
from cpl.api.router import Router
from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
@@ -30,15 +33,16 @@ _logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
super().__init__(services, [auth, api])
self._app: Starlette | None = None
self._api_settings = Configuration.get(ApiSettings)
self._policy_registry = services.get_service(PolicyRegistry)
self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry)
self._routes: list[Route] = []
self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
@@ -100,27 +104,64 @@ class WebApp(ApplicationABC):
return self
def with_routes(self, routes: list[Route]) -> Self:
def with_routes(
self,
routes: list[ApiRoute],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert self._routes is not None, "routes must not be None"
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
self._routes.extend(routes)
assert all(isinstance(route, ApiRoute) for route in routes), "all routes must be of type ApiRoute"
for route in routes:
self.with_route(
route.path,
route.fn,
method,
authentication,
roles,
permissions,
policies,
match,
)
return self
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self:
def with_route(
self,
path: str,
fn: Callable[[Request], Any],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
assert method in [
"GET",
"HEAD",
"POST",
"PUT",
"DELETE",
"PATCH",
"DELETE",
"OPTIONS",
"HEAD",
], "method must be a valid HTTP method"
self._routes.append(Route(path, fn, methods=[method], **kwargs))
Router.route(path, method, registry=self._routes)(fn)
if authentication:
Router.authenticate()(fn)
if roles or permissions or policies:
Router.authorize(roles, permissions, policies, match)(fn)
return self
def with_middleware(self, middleware: PartialMiddleware) -> Self:
@@ -162,7 +203,7 @@ class WebApp(ApplicationABC):
_policies.append(policy)
self._policy_registry.extend_policies(_policies)
self._policies.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware)
return self
@@ -170,7 +211,7 @@ class WebApp(ApplicationABC):
def _validate_policies(self):
for rule in Router.get_authorization_rules():
for policy_name in rule["policies"]:
policy = self._policy_registry.get(policy_name)
policy = self._policies.get(policy_name)
if not policy:
_logger.fatal(f"Authorization policy '{policy_name}' not found")
@@ -179,15 +220,7 @@ class WebApp(ApplicationABC):
self._validate_policies()
if self._app is None:
routes = [
Route(
path=route.path,
endpoint=self._services.inject(route.endpoint),
methods=route.methods,
name=route.name,
)
for route in self._routes + Router.get_routes()
]
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
app = Starlette(
routes=routes,

View File

@@ -25,8 +25,13 @@ class AuthorizationMiddleware(ASGIMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
user = get_user()
url = request.url.path
if url not in Router.get_authorization_rules_paths():
_logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send)
user = get_user()
if not user:
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
@@ -48,9 +53,13 @@ class AuthorizationMiddleware(ASGIMiddleware):
if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
@@ -61,4 +70,4 @@ class AuthorizationMiddleware(ASGIMiddleware):
if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
return await self._call_next(scope, receive, send)

View File

@@ -0,0 +1,43 @@
from typing import Callable
from starlette.routing import Route
from cpl.api.typing import HTTPMethods
class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
self._path = path
self._fn = fn
self._method = method
self._kwargs = kwargs
@property
def name(self) -> str:
return self._fn.__name__
@property
def fn(self) -> Callable:
return self._fn
@property
def path(self) -> str:
return self._path
@property
def method(self) -> HTTPMethods:
return self._method
@property
def kwargs(self) -> dict:
return self._kwargs
def to_starlette(self, wrap_endpoint: Callable = None) -> Route:
return Route(
self._path,
self._fn if not wrap_endpoint else wrap_endpoint(self._fn),
methods=[self._method],
**self._kwargs,
)

View File

@@ -1,23 +1,28 @@
from typing import Optional
from cpl.api.model.policy import Policy
from cpl.core.abc.registry_abc import RegistryABC
class PolicyRegistry:
class PolicyRegistry(RegistryABC):
def __init__(self):
self._policies: dict[str, Policy] = {}
RegistryABC.__init__(self)
def extend_policies(self, policies: list[Policy]):
for policy in policies:
self.add_policy(policy)
def extend(self, items: list[Policy]):
for policy in items:
self.add(policy)
def add_policy(self, policy: Policy):
assert isinstance(policy, Policy), "policy must be an instance of Policy"
def add(self, item: Policy):
assert isinstance(item, Policy), "policy must be an instance of Policy"
if policy.name in self._policies:
raise ValueError(f"Policy {policy.name} is already registered")
if item.name in self._items:
raise ValueError(f"Policy {item.name} is already registered")
self._policies[policy.name] = policy
self._items[item.name] = item
def get(self, name: str) -> Optional[Policy]:
return self._policies.get(name)
def get(self, key: str) -> Optional[Policy]:
return self._items.get(key)
def all(self) -> list[Policy]:
return list(self._items.values())

View File

@@ -0,0 +1,33 @@
from typing import Optional
from cpl.api.model.policy import Policy
from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC
class RouteRegistry(RegistryABC):
def __init__(self):
RegistryABC.__init__(self)
def extend(self, items: list[ApiRoute]):
for policy in items:
self.add(policy)
def add(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
if item.path in self._items:
raise ValueError(f"ApiRoute {item.path} is already registered")
self._items[item.path] = item
def set(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
self._items[item.path] = item
def get(self, key: str) -> Optional[ApiRoute]:
return self._items.get(key)
def all(self) -> list[ApiRoute]:
return list(self._items.values())

View File

@@ -1,26 +1,25 @@
from enum import Enum
from starlette.routing import Route
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods
class Router:
_registered_routes: list[Route] = []
_auth_required: list[str] = []
_authorization_rules: list[dict] = []
@classmethod
def get_routes(cls) -> list[Route]:
return cls._registered_routes
_authorization_rules: dict[str, dict] = {}
@classmethod
def get_auth_required_routes(cls) -> list[str]:
return cls._auth_required
@classmethod
def get_authorization_rules_paths(cls) -> list[str]:
return list(cls._authorization_rules.keys())
@classmethod
def get_authorization_rules(cls) -> list[dict]:
return cls._authorization_rules
return list(cls._authorization_rules.values())
@classmethod
def authenticate(cls):
@@ -42,7 +41,13 @@ class Router:
return inner
@classmethod
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
def authorize(
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
"""
Decorator to mark a route as requiring authorization.
Usage:
@@ -67,52 +72,64 @@ class Router:
permissions[permissions.index(perm)] = perm.value
def inner(fn):
route_path = getattr(fn, "_route_path", None)
if not route_path:
path = getattr(fn, "_route_path", None)
if not path:
return fn
if route_path in cls._authorization_rules:
raise ValueError(f"Route {route_path} is already registered for authorization")
if path in cls._authorization_rules:
raise ValueError(f"Route {path} is already registered for authorization")
cls._authorization_rules.append({
cls._authorization_rules[path] = {
"roles": roles or [],
"permissions": permissions or [],
"policies": policies or [],
"match": match or ValidationMatch.all,
})
}
return fn
return inner
@classmethod
def route(cls, path=None, **kwargs):
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
if not registry:
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else:
routes = registry
def inner(fn):
cls._registered_routes.append(Route(path, fn, **kwargs))
routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path)
return fn
return inner
@classmethod
def get(cls, path=None, **kwargs):
return cls.route(path, methods=["GET"], **kwargs)
def get(cls, path: str, **kwargs):
return cls.route(path, "GET", **kwargs)
@classmethod
def post(cls, path=None, **kwargs):
return cls.route(path, methods=["POST"], **kwargs)
def head(cls, path: str, **kwargs):
return cls.route(path, "HEAD", **kwargs)
@classmethod
def head(cls, path=None, **kwargs):
return cls.route(path, methods=["HEAD"], **kwargs)
def post(cls, path: str, **kwargs):
return cls.route(path, "POST", **kwargs)
@classmethod
def put(cls, path=None, **kwargs):
return cls.route(path, methods=["PUT"], **kwargs)
def put(cls, path: str, **kwargs):
return cls.route(path, "PUT", **kwargs)
@classmethod
def delete(cls, path=None, **kwargs):
return cls.route(path, methods=["DELETE"], **kwargs)
def patch(cls, path: str, **kwargs):
return cls.route(path, "PATCH", **kwargs)
@classmethod
def delete(cls, path: str, **kwargs):
return cls.route(path, "DELETE", **kwargs)
@classmethod
def override(cls):
@@ -125,13 +142,22 @@ class Router:
...
"""
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn):
route_path = getattr(fn, "_route_path", None)
path = getattr(fn, "_route_path", None)
if path is None:
raise ValueError("Cannot override a route that has not been registered yet")
routes = list(filter(lambda x: x.path == route_path, cls._registered_routes))
for route in routes[:-1]:
cls._registered_routes.remove(route)
route = routes.get(path)
if route is None:
raise ValueError(f"Cannot override a route that does not exist: {path}")
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
setattr(fn, "_route_path", path)
return fn
return inner

View File

@@ -9,11 +9,11 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.auth.schema import AuthUser
TRequest = Union[Request, WebSocket]
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
PartialMiddleware = Union[
ASGIMiddleware,
Type[ASGIMiddleware],
Middleware,
Callable[[ASGIApp], ASGIApp],
]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

View File

@@ -0,0 +1,23 @@
from abc import abstractmethod, ABC
from typing import Generic
from cpl.core.typing import T
class RegistryABC(ABC, Generic[T]):
@abstractmethod
def __init__(self):
self._items: dict[str, T] = {}
@abstractmethod
def extend(self, items: list[T]) -> None: ...
@abstractmethod
def add(self, item: T) -> None: ...
@abstractmethod
def get(self, key: str) -> T | None: ...
@abstractmethod
def all(self) -> list[T]: ...

View File

@@ -3,6 +3,7 @@ from starlette.responses import JSONResponse
from cpl import api
from cpl.api.application.web_app import WebApp
from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from service import PingService
@@ -26,7 +27,7 @@ def main():
app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
app.with_routes_directory("routes")
app.run()