Authorization via decorator

This commit is contained in:
2025-09-22 21:16:47 +02:00
parent 12b7c62b69
commit d6b7eb9b30
22 changed files with 280 additions and 41 deletions

View File

@@ -22,5 +22,8 @@ def add_api(collection: _ServiceCollection):
dependency_error("cpl-auth", e) dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
collection.add_singleton(PolicyRegistry)
_ServiceCollection.with_module(add_api, __name__) _ServiceCollection.with_module(add_api, __name__)

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Mapping, Any, Callable from typing import Mapping, Any, Callable, Self, Union
import uvicorn import uvicorn
from starlette.applications import Starlette from starlette.applications import Starlette
@@ -11,20 +11,24 @@ from starlette.routing import Route
from starlette.types import ExceptionHandler from starlette.types import ExceptionHandler
from cpl import api, auth from cpl import api, auth
from cpl.api.api_logger import APILogger from cpl.api.registry.policy import PolicyRegistry
from cpl.api.api_settings import ApiSettings
from cpl.api.error import APIError from cpl.api.error import APIError
from cpl.api.logger import APILogger
from cpl.api.middleware.authentication import AuthenticationMiddleware from cpl.api.middleware.authentication import AuthenticationMiddleware
from cpl.api.middleware.authorization import AuthorizationMiddleware
from cpl.api.middleware.logging import LoggingMiddleware from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware from cpl.api.middleware.request import RequestMiddleware
from cpl.api.model.policy import Policy
from cpl.api.router import Router from cpl.api.router import Router
from cpl.api.typing import HTTPMethods, PartialMiddleware from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger("API") _logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC): class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProviderABC):
@@ -32,6 +36,7 @@ class WebApp(ApplicationABC):
self._app: Starlette | None = None self._app: Starlette | None = None
self._api_settings = Configuration.get(ApiSettings) self._api_settings = Configuration.get(ApiSettings)
self._policy_registry = services.get_service(PolicyRegistry)
self._routes: list[Route] = [] self._routes: list[Route] = []
self._middleware: list[Middleware] = [ self._middleware: list[Middleware] = [
@@ -66,11 +71,12 @@ class WebApp(ApplicationABC):
_logger.debug(f"Allowed origins: {origins}") _logger.debug(f"Allowed origins: {origins}")
return origins.split(",") return origins.split(",")
def with_database(self): def with_database(self) -> Self:
self.with_migrations() self.with_migrations()
self.with_seeders() self.with_seeders()
return self
def with_app(self, app: Starlette): def with_app(self, app: Starlette) -> Self:
assert app is not None, "app must not be None" assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette" assert isinstance(app, Starlette), "app must be an instance of Starlette"
self._app = app self._app = app
@@ -80,7 +86,7 @@ class WebApp(ApplicationABC):
if self._app is not None: if self._app is not None:
raise ValueError("App is already set, cannot add routes or middleware") raise ValueError("App is already set, cannot add routes or middleware")
def with_routes_directory(self, directory: str) -> "WebApp": def with_routes_directory(self, directory: str) -> Self:
self._check_for_app() self._check_for_app()
assert directory is not None, "directory must not be None" assert directory is not None, "directory must not be None"
@@ -94,14 +100,14 @@ class WebApp(ApplicationABC):
return self return self
def with_routes(self, routes: list[Route]) -> "WebApp": def with_routes(self, routes: list[Route]) -> Self:
self._check_for_app() self._check_for_app()
assert self._routes is not None, "routes must not be None" assert self._routes is not None, "routes must not be None"
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route" assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
self._routes.extend(routes) self._routes.extend(routes)
return self return self
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> "WebApp": def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self:
self._check_for_app() self._check_for_app()
assert path is not None, "path must not be None" assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None" assert fn is not None, "fn must not be None"
@@ -117,7 +123,7 @@ class WebApp(ApplicationABC):
self._routes.append(Route(path, fn, methods=[method], **kwargs)) self._routes.append(Route(path, fn, methods=[method], **kwargs))
return self return self
def with_middleware(self, middleware: PartialMiddleware) -> "WebApp": def with_middleware(self, middleware: PartialMiddleware) -> Self:
self._check_for_app() self._check_for_app()
if isinstance(middleware, Middleware): if isinstance(middleware, Middleware):
@@ -129,15 +135,49 @@ class WebApp(ApplicationABC):
return self return self
def with_authentication(self): def with_authentication(self) -> Self:
self.with_middleware(AuthenticationMiddleware) self.with_middleware(AuthenticationMiddleware)
return self return self
def with_authorization(self): def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
pass if policies:
_policies = []
if not isinstance(policies, list):
policies = list(policies)
for i, policy in enumerate(policies):
if isinstance(policy, dict):
for name, resolver in policy.items():
if not isinstance(name, str):
_logger.warning(f"Skipping policy at index {i}, name must be a string")
continue
if not callable(resolver):
_logger.warning(f"Skipping policy {name}, resolver must be callable")
continue
_policies.append(Policy(name, resolver))
continue
_policies.append(policy)
self._policy_registry.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware)
return self
def _validate_policies(self):
for rule in Router.get_authorization_rules():
for policy_name in rule["policies"]:
policy = self._policy_registry.get(policy_name)
if not policy:
_logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self): async def main(self):
_logger.debug(f"Preparing API") _logger.debug(f"Preparing API")
self._validate_policies()
if self._app is None: if self._app is None:
routes = [ routes = [
Route( Route(
@@ -166,13 +206,6 @@ class WebApp(ApplicationABC):
app = self._app app = self._app
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") _logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
# uvicorn.run(
# app,
# host=self._api_settings.host,
# port=self._api_settings.port,
# log_config=None,
# loop="asyncio"
# )
config = uvicorn.Config( config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"

View File

@@ -7,14 +7,23 @@ from starlette.types import Scope, Receive, Send
class APIError(HTTPException): class APIError(HTTPException):
status_code = 500 status_code = 500
@classmethod def __init__(self, message: str = ""):
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send): super().__init__(self.status_code, message)
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code) self._message = message
@property
def error_message(self) -> str:
if self._message:
return f"{type(self).__name__}: {self._message}"
return f"{type(self).__name__}"
async def asgi_response(self, scope: Scope, receive: Receive, send: Send):
r = JSONResponse({"error": self.error_message}, status_code=self.status_code)
return await r(scope, receive, send) return await r(scope, receive, send)
@classmethod def response(self):
def response(cls): return JSONResponse({"error": self.error_message}, status_code=self.status_code)
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
class Unauthorized(APIError): class Unauthorized(APIError):

View File

@@ -2,12 +2,13 @@ from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger from cpl.api.logger import APILogger
from cpl.api.error import Unauthorized from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__) _logger = APILogger(__name__)
@@ -53,6 +54,9 @@ class AuthenticationMiddleware(ASGIMiddleware):
_logger.debug(f"Unauthorized access to {url}, user is deleted") _logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send) return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user
set_user(user)
return await self._call_next(scope, receive, send) return await self._call_next(scope, receive, send)
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser: async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:

View File

@@ -0,0 +1,64 @@
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized, Forbidden
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger(__name__)
class AuthorizationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._policies = policies
self._user_dao = user_dao
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
user = get_user()
if not user:
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
roles = await user.roles
request.state.roles = roles
role_names = [r.name for r in roles]
perms = await user.permissions
request.state.permissions = perms
perm_names = [p.name for p in perms]
for rule in Router.get_authorization_rules():
match = rule["match"]
if rule["roles"]:
if match == ValidationMatch.all and not all(r in role_names for r in rule["roles"]):
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(r in role_names for r in rule["roles"]):
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
_logger.warning(f"Authorization policy '{policy_name}' not found")
continue
if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)

View File

@@ -4,7 +4,7 @@ from starlette.requests import Request
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
_logger = APILogger(__name__) _logger = APILogger(__name__)

View File

@@ -5,10 +5,9 @@ from uuid import uuid4
from starlette.requests import Request from starlette.requests import Request
from starlette.types import Scope, Receive, Send from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
@@ -50,5 +49,5 @@ class RequestMiddleware(ASGIMiddleware):
_request_context.reset(self._ctx_token) _request_context.reset(self._ctx_token)
def get_request() -> Optional[Union[TRequest, WebSocket]]: def get_request() -> Optional[TRequest]:
return _request_context.get() return _request_context.get()

View File

View File

@@ -0,0 +1,34 @@
from asyncio import iscoroutinefunction
from typing import Optional, Any, Coroutine, Awaitable
from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user
class Policy:
def __init__(
self,
name: str,
resolver: PolicyResolver = None,
):
self._name = name
self._resolver: Optional[PolicyResolver] = resolver
@property
def name(self) -> str:
return self._name
@property
def resolvers(self) -> PolicyResolver:
return self._resolver
async def resolve(self, *args, **kwargs) -> bool:
if not self._resolver:
return True
if callable(self._resolver):
if iscoroutinefunction(self._resolver):
return await self._resolver(get_user())
return self._resolver(get_user())
return False

View File

@@ -0,0 +1,6 @@
from enum import Enum
class ValidationMatch(Enum):
any = "any"
all = "all"

View File

View File

@@ -0,0 +1,23 @@
from typing import Optional
from cpl.api.model.policy import Policy
class PolicyRegistry:
def __init__(self):
self._policies: dict[str, Policy] = {}
def extend_policies(self, policies: list[Policy]):
for policy in policies:
self.add_policy(policy)
def add_policy(self, policy: Policy):
assert isinstance(policy, Policy), "policy must be an instance of Policy"
if policy.name in self._policies:
raise ValueError(f"Policy {policy.name} is already registered")
self._policies[policy.name] = policy
def get(self, name: str) -> Optional[Policy]:
return self._policies.get(name)

View File

@@ -1,9 +1,14 @@
from enum import Enum
from starlette.routing import Route from starlette.routing import Route
from cpl.api.model.validation_match import ValidationMatch
class Router: class Router:
_registered_routes: list[Route] = [] _registered_routes: list[Route] = []
_auth_required: list[str] = [] _auth_required: list[str] = []
_authorization_rules: list[dict] = []
@classmethod @classmethod
def get_routes(cls) -> list[Route]: def get_routes(cls) -> list[Route]:
@@ -13,6 +18,10 @@ class Router:
def get_auth_required_routes(cls) -> list[str]: def get_auth_required_routes(cls) -> list[str]:
return cls._auth_required return cls._auth_required
@classmethod
def get_authorization_rules(cls) -> list[dict]:
return cls._authorization_rules
@classmethod @classmethod
def authenticate(cls): def authenticate(cls):
""" """
@@ -32,6 +41,50 @@ class Router:
return inner return inner
@classmethod
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
"""
Decorator to mark a route as requiring authorization.
Usage:
@Route.authorize()
@Route.get("/example")
async def example_endpoint(request: TRequest):
...
"""
assert roles is None or isinstance(roles, list), "roles must be a list of strings"
assert permissions is None or isinstance(permissions, list), "permissions must be a list of strings"
assert policies is None or isinstance(policies, list), "policies must be a list of strings"
assert match is None or isinstance(match, ValidationMatch), "match must be an instance of ValidationMatch"
if roles is not None:
for role in roles:
if isinstance(role, Enum):
roles[roles.index(role)] = role.value
if permissions is not None:
for perm in permissions:
if isinstance(perm, Enum):
permissions[permissions.index(perm)] = perm.value
def inner(fn):
route_path = getattr(fn, "_route_path", None)
if not route_path:
return fn
if route_path in cls._authorization_rules:
raise ValueError(f"Route {route_path} is already registered for authorization")
cls._authorization_rules.append({
"roles": roles or [],
"permissions": permissions or [],
"policies": policies or [],
"match": match or ValidationMatch.all,
})
return fn
return inner
@classmethod @classmethod
def route(cls, path=None, **kwargs): def route(cls, path=None, **kwargs):
def inner(fn): def inner(fn):

View File

@@ -1,13 +1,19 @@
from typing import Union, Literal, Callable from typing import Union, Literal, Callable, Type, Awaitable
from urllib.request import Request from urllib.request import Request
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.types import ASGIApp from starlette.types import ASGIApp
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.auth.schema import AuthUser
TRequest = Union[Request, WebSocket] TRequest = Union[Request, WebSocket]
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
PartialMiddleware = Union[ PartialMiddleware = Union[
ASGIMiddleware,
Type[ASGIMiddleware],
Middleware, Middleware,
Callable[[ASGIApp], ASGIApp], Callable[[ASGIApp], ASGIApp],
] ]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

@@ -43,9 +43,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value) p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""
SELECT COUNT(*) SELECT COUNT(*) as count
FROM permission.role_users ru FROM {TableManager.get("role_users")} ru
JOIN permission.role_permissions rp ON ru.roleId = rp.roleId JOIN {TableManager.get("role_permissions")} rp ON ru.roleId = rp.roleId
WHERE ru.userId = {user_id} WHERE ru.userId = {user_id}
AND rp.permissionId = {p.id} AND rp.permissionId = {p.id}
AND ru.deleted = FALSE AND ru.deleted = FALSE
@@ -61,9 +61,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""
SELECT p.* SELECT p.*
FROM permission.permissions p FROM {TableManager.get("permissions")} p
JOIN permission.role_permissions rp ON p.id = rp.permissionId JOIN {TableManager.get("role_permissions")} rp ON p.id = rp.permissionId
JOIN permission.role_users ru ON rp.roleId = ru.roleId JOIN {TableManager.get("role_users")} ru ON rp.roleId = ru.roleId
WHERE ru.userId = {user_id} WHERE ru.userId = {user_id}
AND rp.deleted = FALSE AND rp.deleted = FALSE
AND ru.deleted = FALSE; AND ru.deleted = FALSE;

View File

@@ -487,7 +487,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
builder.with_temp_table(self._external_fields[temp]) builder.with_temp_table(self._external_fields[temp])
if for_count: if for_count:
builder.with_attribute("COUNT(*)", ignore_table_name=True) builder.with_attribute("COUNT(*) as count", ignore_table_name=True)
else: else:
builder.with_attribute("*") builder.with_attribute("*")

View File

@@ -33,7 +33,7 @@ class TableManager:
}, },
"role_users": { "role_users": {
ServerTypes.POSTGRES: "permission.role_users", ServerTypes.POSTGRES: "permission.role_users",
ServerTypes.MYSQL: "permission_role_users", ServerTypes.MYSQL: "permission_role_auth_users",
}, },
} }

View File

@@ -1,7 +1,7 @@
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from cpl import api from cpl import api
from cpl.api.web_app import WebApp from cpl.api.application.web_app import WebApp
from cpl.application import ApplicationBuilder from cpl.application import ApplicationBuilder
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment from cpl.core.environment import Environment
@@ -24,6 +24,8 @@ def main():
app.with_database() app.with_database()
app.with_authentication() app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET") app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
app.with_routes_directory("routes") app.with_routes_directory("routes")

View File

@@ -3,11 +3,14 @@ from urllib.request import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.permission.permissions import Permissions
from cpl.core.log import Logger from cpl.core.log import Logger
from service import PingService from service import PingService
@Router.authenticate() @Router.authenticate()
@Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping") @Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: Logger): async def ping(r: Request, ping: PingService, logger: Logger):
logger.info(f"Ping: {ping}") logger.info(f"Ping: {ping}")