Compare commits
2 Commits
2025.09.22
...
2025.09.22
| Author | SHA1 | Date | |
|---|---|---|---|
| 77d821bb6e | |||
| 86ad953ff1 |
@@ -1,5 +1,8 @@
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
|
||||
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
|
||||
from .logger import APILogger
|
||||
from .settings import ApiSettings
|
||||
|
||||
def add_api(collection: _ServiceCollection):
|
||||
try:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .asgi_middleware_abc import ASGIMiddleware
|
||||
@@ -0,0 +1 @@
|
||||
from .web_app import WebApp
|
||||
@@ -33,6 +33,7 @@ _logger = APILogger("API")
|
||||
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
super().__init__(services, [auth, api])
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .authentication import AuthenticationMiddleware
|
||||
from .authorization import AuthorizationMiddleware
|
||||
from .logging import LoggingMiddleware
|
||||
from .request import RequestMiddleware
|
||||
|
||||
@@ -53,9 +53,13 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
if rule["permissions"]:
|
||||
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
@@ -66,4 +70,4 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
if not await policy.resolve(user):
|
||||
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
return await self._call_next(scope, receive, send)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .api_route import ApiRoute
|
||||
from .policy import Policy
|
||||
from .validation_match import ValidationMatch
|
||||
@@ -7,13 +7,7 @@ from cpl.api.typing import HTTPMethods
|
||||
|
||||
class ApiRoute:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
fn: Callable,
|
||||
method: HTTPMethods,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
|
||||
self._path = path
|
||||
self._fn = fn
|
||||
self._method = method
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from .policy import PolicyRegistry
|
||||
from .route import RouteRegistry
|
||||
@@ -41,7 +41,13 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
|
||||
def authorize(
|
||||
cls,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
):
|
||||
"""
|
||||
Decorator to mark a route as requiring authorization.
|
||||
Usage:
|
||||
@@ -85,15 +91,15 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||
if not registry:
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
|
||||
def inner(fn):
|
||||
routes.add(ApiRoute(path, fn, method, **kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
@@ -138,7 +144,9 @@ class Router:
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
|
||||
def inner(fn):
|
||||
path = getattr(fn, "_route_path", None)
|
||||
if path is None:
|
||||
@@ -147,7 +155,7 @@ class Router:
|
||||
route = routes.get(path)
|
||||
if route is None:
|
||||
raise ValueError(f"Cannot override a route that does not exist: {path}")
|
||||
|
||||
|
||||
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
@@ -16,4 +16,4 @@ PartialMiddleware = Union[
|
||||
Middleware,
|
||||
Callable[[ASGIApp], ASGIApp],
|
||||
]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
|
||||
Reference in New Issue
Block a user