WIP: dev into master #184
@@ -22,5 +22,8 @@ def add_api(collection: _ServiceCollection):
|
|||||||
|
|
||||||
dependency_error("cpl-auth", e)
|
dependency_error("cpl-auth", e)
|
||||||
|
|
||||||
|
from cpl.api.registry.policy import PolicyRegistry
|
||||||
|
collection.add_singleton(PolicyRegistry)
|
||||||
|
|
||||||
|
|
||||||
_ServiceCollection.with_module(add_api, __name__)
|
_ServiceCollection.with_module(add_api, __name__)
|
||||||
|
|||||||
0
src/cpl-api/cpl/api/application/__init__.py
Normal file
0
src/cpl-api/cpl/api/application/__init__.py
Normal file
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Mapping, Any, Callable
|
from typing import Mapping, Any, Callable, Self, Union
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
@@ -11,20 +11,24 @@ from starlette.routing import Route
|
|||||||
from starlette.types import ExceptionHandler
|
from starlette.types import ExceptionHandler
|
||||||
|
|
||||||
from cpl import api, auth
|
from cpl import api, auth
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.registry.policy import PolicyRegistry
|
||||||
from cpl.api.api_settings import ApiSettings
|
|
||||||
from cpl.api.error import APIError
|
from cpl.api.error import APIError
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||||
|
from cpl.api.middleware.authorization import AuthorizationMiddleware
|
||||||
from cpl.api.middleware.logging import LoggingMiddleware
|
from cpl.api.middleware.logging import LoggingMiddleware
|
||||||
from cpl.api.middleware.request import RequestMiddleware
|
from cpl.api.middleware.request import RequestMiddleware
|
||||||
|
from cpl.api.model.policy import Policy
|
||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.api.typing import HTTPMethods, PartialMiddleware
|
from cpl.api.settings import ApiSettings
|
||||||
|
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||||
from cpl.application.abc.application_abc import ApplicationABC
|
from cpl.application.abc.application_abc import ApplicationABC
|
||||||
from cpl.core.configuration import Configuration
|
from cpl.core.configuration import Configuration
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
_logger = APILogger("API")
|
_logger = APILogger("API")
|
||||||
|
|
||||||
|
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||||
|
|
||||||
class WebApp(ApplicationABC):
|
class WebApp(ApplicationABC):
|
||||||
def __init__(self, services: ServiceProviderABC):
|
def __init__(self, services: ServiceProviderABC):
|
||||||
@@ -32,6 +36,7 @@ class WebApp(ApplicationABC):
|
|||||||
self._app: Starlette | None = None
|
self._app: Starlette | None = None
|
||||||
|
|
||||||
self._api_settings = Configuration.get(ApiSettings)
|
self._api_settings = Configuration.get(ApiSettings)
|
||||||
|
self._policy_registry = services.get_service(PolicyRegistry)
|
||||||
|
|
||||||
self._routes: list[Route] = []
|
self._routes: list[Route] = []
|
||||||
self._middleware: list[Middleware] = [
|
self._middleware: list[Middleware] = [
|
||||||
@@ -66,11 +71,12 @@ class WebApp(ApplicationABC):
|
|||||||
_logger.debug(f"Allowed origins: {origins}")
|
_logger.debug(f"Allowed origins: {origins}")
|
||||||
return origins.split(",")
|
return origins.split(",")
|
||||||
|
|
||||||
def with_database(self):
|
def with_database(self) -> Self:
|
||||||
self.with_migrations()
|
self.with_migrations()
|
||||||
self.with_seeders()
|
self.with_seeders()
|
||||||
|
return self
|
||||||
|
|
||||||
def with_app(self, app: Starlette):
|
def with_app(self, app: Starlette) -> Self:
|
||||||
assert app is not None, "app must not be None"
|
assert app is not None, "app must not be None"
|
||||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||||
self._app = app
|
self._app = app
|
||||||
@@ -80,7 +86,7 @@ class WebApp(ApplicationABC):
|
|||||||
if self._app is not None:
|
if self._app is not None:
|
||||||
raise ValueError("App is already set, cannot add routes or middleware")
|
raise ValueError("App is already set, cannot add routes or middleware")
|
||||||
|
|
||||||
def with_routes_directory(self, directory: str) -> "WebApp":
|
def with_routes_directory(self, directory: str) -> Self:
|
||||||
self._check_for_app()
|
self._check_for_app()
|
||||||
assert directory is not None, "directory must not be None"
|
assert directory is not None, "directory must not be None"
|
||||||
|
|
||||||
@@ -94,14 +100,14 @@ class WebApp(ApplicationABC):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_routes(self, routes: list[Route]) -> "WebApp":
|
def with_routes(self, routes: list[Route]) -> Self:
|
||||||
self._check_for_app()
|
self._check_for_app()
|
||||||
assert self._routes is not None, "routes must not be None"
|
assert self._routes is not None, "routes must not be None"
|
||||||
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
|
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
|
||||||
self._routes.extend(routes)
|
self._routes.extend(routes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> "WebApp":
|
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self:
|
||||||
self._check_for_app()
|
self._check_for_app()
|
||||||
assert path is not None, "path must not be None"
|
assert path is not None, "path must not be None"
|
||||||
assert fn is not None, "fn must not be None"
|
assert fn is not None, "fn must not be None"
|
||||||
@@ -117,7 +123,7 @@ class WebApp(ApplicationABC):
|
|||||||
self._routes.append(Route(path, fn, methods=[method], **kwargs))
|
self._routes.append(Route(path, fn, methods=[method], **kwargs))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_middleware(self, middleware: PartialMiddleware) -> "WebApp":
|
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
||||||
self._check_for_app()
|
self._check_for_app()
|
||||||
|
|
||||||
if isinstance(middleware, Middleware):
|
if isinstance(middleware, Middleware):
|
||||||
@@ -129,15 +135,49 @@ class WebApp(ApplicationABC):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_authentication(self):
|
def with_authentication(self) -> Self:
|
||||||
self.with_middleware(AuthenticationMiddleware)
|
self.with_middleware(AuthenticationMiddleware)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_authorization(self):
|
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
|
||||||
pass
|
if policies:
|
||||||
|
_policies = []
|
||||||
|
|
||||||
|
if not isinstance(policies, list):
|
||||||
|
policies = list(policies)
|
||||||
|
|
||||||
|
for i, policy in enumerate(policies):
|
||||||
|
if isinstance(policy, dict):
|
||||||
|
for name, resolver in policy.items():
|
||||||
|
if not isinstance(name, str):
|
||||||
|
_logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not callable(resolver):
|
||||||
|
_logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||||
|
continue
|
||||||
|
|
||||||
|
_policies.append(Policy(name, resolver))
|
||||||
|
continue
|
||||||
|
|
||||||
|
_policies.append(policy)
|
||||||
|
|
||||||
|
self._policy_registry.extend_policies(_policies)
|
||||||
|
|
||||||
|
self.with_middleware(AuthorizationMiddleware)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _validate_policies(self):
|
||||||
|
for rule in Router.get_authorization_rules():
|
||||||
|
for policy_name in rule["policies"]:
|
||||||
|
policy = self._policy_registry.get(policy_name)
|
||||||
|
if not policy:
|
||||||
|
_logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||||
|
|
||||||
async def main(self):
|
async def main(self):
|
||||||
_logger.debug(f"Preparing API")
|
_logger.debug(f"Preparing API")
|
||||||
|
self._validate_policies()
|
||||||
|
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
routes = [
|
routes = [
|
||||||
Route(
|
Route(
|
||||||
@@ -166,13 +206,6 @@ class WebApp(ApplicationABC):
|
|||||||
app = self._app
|
app = self._app
|
||||||
|
|
||||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||||
# uvicorn.run(
|
|
||||||
# app,
|
|
||||||
# host=self._api_settings.host,
|
|
||||||
# port=self._api_settings.port,
|
|
||||||
# log_config=None,
|
|
||||||
# loop="asyncio"
|
|
||||||
# )
|
|
||||||
|
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
|
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
|
||||||
@@ -7,14 +7,23 @@ from starlette.types import Scope, Receive, Send
|
|||||||
class APIError(HTTPException):
|
class APIError(HTTPException):
|
||||||
status_code = 500
|
status_code = 500
|
||||||
|
|
||||||
@classmethod
|
def __init__(self, message: str = ""):
|
||||||
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send):
|
super().__init__(self.status_code, message)
|
||||||
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
self._message = message
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error_message(self) -> str:
|
||||||
|
if self._message:
|
||||||
|
return f"{type(self).__name__}: {self._message}"
|
||||||
|
|
||||||
|
return f"{type(self).__name__}"
|
||||||
|
|
||||||
|
async def asgi_response(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
r = JSONResponse({"error": self.error_message}, status_code=self.status_code)
|
||||||
return await r(scope, receive, send)
|
return await r(scope, receive, send)
|
||||||
|
|
||||||
@classmethod
|
def response(self):
|
||||||
def response(cls):
|
return JSONResponse({"error": self.error_message}, status_code=self.status_code)
|
||||||
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
|
||||||
|
|
||||||
|
|
||||||
class Unauthorized(APIError):
|
class Unauthorized(APIError):
|
||||||
|
|||||||
@@ -2,12 +2,13 @@ from keycloak import KeycloakAuthenticationError
|
|||||||
from starlette.types import Scope, Receive, Send
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.error import Unauthorized
|
from cpl.api.error import Unauthorized
|
||||||
from cpl.api.middleware.request import get_request
|
from cpl.api.middleware.request import get_request
|
||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.auth.keycloak import KeycloakClient
|
from cpl.auth.keycloak import KeycloakClient
|
||||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||||
|
from cpl.core.ctx import set_user
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
@@ -53,6 +54,9 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
|||||||
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||||
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
request.state.user = user
|
||||||
|
set_user(user)
|
||||||
|
|
||||||
return await self._call_next(scope, receive, send)
|
return await self._call_next(scope, receive, send)
|
||||||
|
|
||||||
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
|
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
|
||||||
|
|||||||
64
src/cpl-api/cpl/api/middleware/authorization.py
Normal file
64
src/cpl-api/cpl/api/middleware/authorization.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.api.error import Unauthorized, Forbidden
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
|
from cpl.api.model.validation_match import ValidationMatch
|
||||||
|
from cpl.api.registry.policy import PolicyRegistry
|
||||||
|
from cpl.api.router import Router
|
||||||
|
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||||
|
from cpl.core.ctx.user_context import get_user
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
|
@ServiceProviderABC.inject
|
||||||
|
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||||
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
self._policies = policies
|
||||||
|
self._user_dao = user_dao
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
request = get_request()
|
||||||
|
user = get_user()
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
roles = await user.roles
|
||||||
|
request.state.roles = roles
|
||||||
|
role_names = [r.name for r in roles]
|
||||||
|
|
||||||
|
perms = await user.permissions
|
||||||
|
request.state.permissions = perms
|
||||||
|
perm_names = [p.name for p in perms]
|
||||||
|
|
||||||
|
for rule in Router.get_authorization_rules():
|
||||||
|
match = rule["match"]
|
||||||
|
if rule["roles"]:
|
||||||
|
if match == ValidationMatch.all and not all(r in role_names for r in rule["roles"]):
|
||||||
|
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
|
||||||
|
if match == ValidationMatch.any and not any(r in role_names for r in rule["roles"]):
|
||||||
|
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
if rule["permissions"]:
|
||||||
|
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
|
||||||
|
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||||
|
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
|
||||||
|
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
for policy_name in rule["policies"]:
|
||||||
|
policy = self._policies.get(policy_name)
|
||||||
|
if not policy:
|
||||||
|
_logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not await policy.resolve(user):
|
||||||
|
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
return await self._call_next(scope, receive, send)
|
||||||
@@ -4,7 +4,7 @@ from starlette.requests import Request
|
|||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.middleware.request import get_request
|
from cpl.api.middleware.request import get_request
|
||||||
|
|
||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.types import Scope, Receive, Send
|
from starlette.types import Scope, Receive, Send
|
||||||
from starlette.websockets import WebSocket
|
|
||||||
|
|
||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.typing import TRequest
|
from cpl.api.typing import TRequest
|
||||||
|
|
||||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||||
@@ -50,5 +49,5 @@ class RequestMiddleware(ASGIMiddleware):
|
|||||||
_request_context.reset(self._ctx_token)
|
_request_context.reset(self._ctx_token)
|
||||||
|
|
||||||
|
|
||||||
def get_request() -> Optional[Union[TRequest, WebSocket]]:
|
def get_request() -> Optional[TRequest]:
|
||||||
return _request_context.get()
|
return _request_context.get()
|
||||||
|
|||||||
0
src/cpl-api/cpl/api/model/__init__.py
Normal file
0
src/cpl-api/cpl/api/model/__init__.py
Normal file
34
src/cpl-api/cpl/api/model/policy.py
Normal file
34
src/cpl-api/cpl/api/model/policy.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
from typing import Optional, Any, Coroutine, Awaitable
|
||||||
|
|
||||||
|
from cpl.api.typing import PolicyResolver
|
||||||
|
from cpl.core.ctx import get_user
|
||||||
|
|
||||||
|
|
||||||
|
class Policy:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
resolver: PolicyResolver = None,
|
||||||
|
):
|
||||||
|
self._name = name
|
||||||
|
self._resolver: Optional[PolicyResolver] = resolver
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resolvers(self) -> PolicyResolver:
|
||||||
|
return self._resolver
|
||||||
|
|
||||||
|
async def resolve(self, *args, **kwargs) -> bool:
|
||||||
|
if not self._resolver:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if callable(self._resolver):
|
||||||
|
if iscoroutinefunction(self._resolver):
|
||||||
|
return await self._resolver(get_user())
|
||||||
|
|
||||||
|
return self._resolver(get_user())
|
||||||
|
return False
|
||||||
6
src/cpl-api/cpl/api/model/validation_match.py
Normal file
6
src/cpl-api/cpl/api/model/validation_match.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationMatch(Enum):
|
||||||
|
any = "any"
|
||||||
|
all = "all"
|
||||||
0
src/cpl-api/cpl/api/registry/__init__.py
Normal file
0
src/cpl-api/cpl/api/registry/__init__.py
Normal file
23
src/cpl-api/cpl/api/registry/policy.py
Normal file
23
src/cpl-api/cpl/api/registry/policy.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cpl.api.model.policy import Policy
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyRegistry:
|
||||||
|
def __init__(self):
|
||||||
|
self._policies: dict[str, Policy] = {}
|
||||||
|
|
||||||
|
def extend_policies(self, policies: list[Policy]):
|
||||||
|
for policy in policies:
|
||||||
|
self.add_policy(policy)
|
||||||
|
|
||||||
|
def add_policy(self, policy: Policy):
|
||||||
|
assert isinstance(policy, Policy), "policy must be an instance of Policy"
|
||||||
|
|
||||||
|
if policy.name in self._policies:
|
||||||
|
raise ValueError(f"Policy {policy.name} is already registered")
|
||||||
|
|
||||||
|
self._policies[policy.name] = policy
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[Policy]:
|
||||||
|
return self._policies.get(name)
|
||||||
@@ -1,9 +1,14 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from cpl.api.model.validation_match import ValidationMatch
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
_registered_routes: list[Route] = []
|
_registered_routes: list[Route] = []
|
||||||
_auth_required: list[str] = []
|
_auth_required: list[str] = []
|
||||||
|
_authorization_rules: list[dict] = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_routes(cls) -> list[Route]:
|
def get_routes(cls) -> list[Route]:
|
||||||
@@ -13,6 +18,10 @@ class Router:
|
|||||||
def get_auth_required_routes(cls) -> list[str]:
|
def get_auth_required_routes(cls) -> list[str]:
|
||||||
return cls._auth_required
|
return cls._auth_required
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_authorization_rules(cls) -> list[dict]:
|
||||||
|
return cls._authorization_rules
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def authenticate(cls):
|
def authenticate(cls):
|
||||||
"""
|
"""
|
||||||
@@ -32,6 +41,50 @@ class Router:
|
|||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
|
||||||
|
"""
|
||||||
|
Decorator to mark a route as requiring authorization.
|
||||||
|
Usage:
|
||||||
|
@Route.authorize()
|
||||||
|
@Route.get("/example")
|
||||||
|
async def example_endpoint(request: TRequest):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
assert roles is None or isinstance(roles, list), "roles must be a list of strings"
|
||||||
|
assert permissions is None or isinstance(permissions, list), "permissions must be a list of strings"
|
||||||
|
assert policies is None or isinstance(policies, list), "policies must be a list of strings"
|
||||||
|
assert match is None or isinstance(match, ValidationMatch), "match must be an instance of ValidationMatch"
|
||||||
|
|
||||||
|
if roles is not None:
|
||||||
|
for role in roles:
|
||||||
|
if isinstance(role, Enum):
|
||||||
|
roles[roles.index(role)] = role.value
|
||||||
|
|
||||||
|
if permissions is not None:
|
||||||
|
for perm in permissions:
|
||||||
|
if isinstance(perm, Enum):
|
||||||
|
permissions[permissions.index(perm)] = perm.value
|
||||||
|
|
||||||
|
def inner(fn):
|
||||||
|
route_path = getattr(fn, "_route_path", None)
|
||||||
|
if not route_path:
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if route_path in cls._authorization_rules:
|
||||||
|
raise ValueError(f"Route {route_path} is already registered for authorization")
|
||||||
|
|
||||||
|
cls._authorization_rules.append({
|
||||||
|
"roles": roles or [],
|
||||||
|
"permissions": permissions or [],
|
||||||
|
"policies": policies or [],
|
||||||
|
"match": match or ValidationMatch.all,
|
||||||
|
})
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def route(cls, path=None, **kwargs):
|
def route(cls, path=None, **kwargs):
|
||||||
def inner(fn):
|
def inner(fn):
|
||||||
|
|||||||
@@ -1,13 +1,19 @@
|
|||||||
from typing import Union, Literal, Callable
|
from typing import Union, Literal, Callable, Type, Awaitable
|
||||||
from urllib.request import Request
|
from urllib.request import Request
|
||||||
|
|
||||||
from starlette.middleware import Middleware
|
from starlette.middleware import Middleware
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.auth.schema import AuthUser
|
||||||
|
|
||||||
TRequest = Union[Request, WebSocket]
|
TRequest = Union[Request, WebSocket]
|
||||||
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
||||||
PartialMiddleware = Union[
|
PartialMiddleware = Union[
|
||||||
|
ASGIMiddleware,
|
||||||
|
Type[ASGIMiddleware],
|
||||||
Middleware,
|
Middleware,
|
||||||
Callable[[ASGIApp], ASGIApp],
|
Callable[[ASGIApp], ASGIApp],
|
||||||
]
|
]
|
||||||
|
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||||
@@ -43,9 +43,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
|||||||
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
||||||
result = await self._db.select_map(
|
result = await self._db.select_map(
|
||||||
f"""
|
f"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*) as count
|
||||||
FROM permission.role_users ru
|
FROM {TableManager.get("role_users")} ru
|
||||||
JOIN permission.role_permissions rp ON ru.roleId = rp.roleId
|
JOIN {TableManager.get("role_permissions")} rp ON ru.roleId = rp.roleId
|
||||||
WHERE ru.userId = {user_id}
|
WHERE ru.userId = {user_id}
|
||||||
AND rp.permissionId = {p.id}
|
AND rp.permissionId = {p.id}
|
||||||
AND ru.deleted = FALSE
|
AND ru.deleted = FALSE
|
||||||
@@ -61,9 +61,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
|||||||
result = await self._db.select_map(
|
result = await self._db.select_map(
|
||||||
f"""
|
f"""
|
||||||
SELECT p.*
|
SELECT p.*
|
||||||
FROM permission.permissions p
|
FROM {TableManager.get("permissions")} p
|
||||||
JOIN permission.role_permissions rp ON p.id = rp.permissionId
|
JOIN {TableManager.get("role_permissions")} rp ON p.id = rp.permissionId
|
||||||
JOIN permission.role_users ru ON rp.roleId = ru.roleId
|
JOIN {TableManager.get("role_users")} ru ON rp.roleId = ru.roleId
|
||||||
WHERE ru.userId = {user_id}
|
WHERE ru.userId = {user_id}
|
||||||
AND rp.deleted = FALSE
|
AND rp.deleted = FALSE
|
||||||
AND ru.deleted = FALSE;
|
AND ru.deleted = FALSE;
|
||||||
|
|||||||
@@ -487,7 +487,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
|||||||
builder.with_temp_table(self._external_fields[temp])
|
builder.with_temp_table(self._external_fields[temp])
|
||||||
|
|
||||||
if for_count:
|
if for_count:
|
||||||
builder.with_attribute("COUNT(*)", ignore_table_name=True)
|
builder.with_attribute("COUNT(*) as count", ignore_table_name=True)
|
||||||
else:
|
else:
|
||||||
builder.with_attribute("*")
|
builder.with_attribute("*")
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class TableManager:
|
|||||||
},
|
},
|
||||||
"role_users": {
|
"role_users": {
|
||||||
ServerTypes.POSTGRES: "permission.role_users",
|
ServerTypes.POSTGRES: "permission.role_users",
|
||||||
ServerTypes.MYSQL: "permission_role_users",
|
ServerTypes.MYSQL: "permission_role_auth_users",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from cpl import api
|
from cpl import api
|
||||||
from cpl.api.web_app import WebApp
|
from cpl.api.application.web_app import WebApp
|
||||||
from cpl.application import ApplicationBuilder
|
from cpl.application import ApplicationBuilder
|
||||||
from cpl.core.configuration import Configuration
|
from cpl.core.configuration import Configuration
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
@@ -24,6 +24,8 @@ def main():
|
|||||||
app.with_database()
|
app.with_database()
|
||||||
|
|
||||||
app.with_authentication()
|
app.with_authentication()
|
||||||
|
app.with_authorization()
|
||||||
|
|
||||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
||||||
app.with_routes_directory("routes")
|
app.with_routes_directory("routes")
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ from urllib.request import Request
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
|
from cpl.auth.permission.permissions import Permissions
|
||||||
from cpl.core.log import Logger
|
from cpl.core.log import Logger
|
||||||
from service import PingService
|
from service import PingService
|
||||||
|
|
||||||
|
|
||||||
@Router.authenticate()
|
@Router.authenticate()
|
||||||
|
@Router.authorize(permissions=[Permissions.administrator])
|
||||||
|
# @Router.authorize(policies=["test"])
|
||||||
@Router.get(f"/ping")
|
@Router.get(f"/ping")
|
||||||
async def ping(r: Request, ping: PingService, logger: Logger):
|
async def ping(r: Request, ping: PingService, logger: Logger):
|
||||||
logger.info(f"Ping: {ping}")
|
logger.info(f"Ping: {ping}")
|
||||||
|
|||||||
Reference in New Issue
Block a user