Compare commits
6 Commits
2025.09.22
...
2025.09.24
| Author | SHA1 | Date | |
|---|---|---|---|
| c71a3df62c | |||
| e296c0992b | |||
| 6639946346 | |||
| b9ac11e15f | |||
| 77d821bb6e | |||
| 86ad953ff1 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -139,3 +139,6 @@ PythonImportHelper-v2-Completion.json
|
||||
|
||||
# cpl unittest stuff
|
||||
unittests/test_*_playground
|
||||
|
||||
# cpl logs
|
||||
**/logs/*.jsonl
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
|
||||
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
|
||||
from .logger import APILogger
|
||||
from .settings import ApiSettings
|
||||
|
||||
|
||||
def add_api(collection: _ServiceCollection):
|
||||
try:
|
||||
@@ -23,7 +27,10 @@ def add_api(collection: _ServiceCollection):
|
||||
dependency_error("cpl-auth", e)
|
||||
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
|
||||
collection.add_singleton(PolicyRegistry)
|
||||
collection.add_singleton(RouteRegistry)
|
||||
|
||||
|
||||
_ServiceCollection.with_module(add_api, __name__)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .asgi_middleware_abc import ASGIMiddleware
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .web_app import WebApp
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Mapping, Any, Callable, Self, Union
|
||||
|
||||
import uvicorn
|
||||
@@ -7,18 +8,20 @@ from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.types import ExceptionHandler
|
||||
|
||||
from cpl import api, auth
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.error import APIError
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||
from cpl.api.middleware.authorization import AuthorizationMiddleware
|
||||
from cpl.api.middleware.logging import LoggingMiddleware
|
||||
from cpl.api.middleware.request import RequestMiddleware
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.api.model.policy import Policy
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.api.settings import ApiSettings
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||
@@ -26,19 +29,21 @@ from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = APILogger("API")
|
||||
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
super().__init__(services, [auth, api])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
self._api_settings = Configuration.get(ApiSettings)
|
||||
self._policy_registry = services.get_service(PolicyRegistry)
|
||||
self._logger = services.get_service(APILogger)
|
||||
|
||||
self._api_settings = Configuration.get(ApiSettings)
|
||||
self._policies = services.get_service(PolicyRegistry)
|
||||
self._routes = services.get_service(RouteRegistry)
|
||||
|
||||
self._routes: list[Route] = []
|
||||
self._middleware: list[Middleware] = [
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
@@ -48,16 +53,15 @@ class WebApp(ApplicationABC):
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _handle_exception(request: Request, exc: Exception):
|
||||
async def _handle_exception(self, request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
_logger.error(exc)
|
||||
self._logger.error(exc)
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
if hasattr(request.state, "request_id"):
|
||||
_logger.error(f"Request {request.state.request_id}", exc)
|
||||
self._logger.error(f"Request {request.state.request_id}", exc)
|
||||
else:
|
||||
_logger.error("Request unknown", exc)
|
||||
self._logger.error("Request unknown", exc)
|
||||
|
||||
return JSONResponse({"error": str(exc)}, status_code=500)
|
||||
|
||||
@@ -65,10 +69,10 @@ class WebApp(ApplicationABC):
|
||||
origins = self._api_settings.allowed_origins
|
||||
|
||||
if origins is None or origins == "":
|
||||
_logger.warning("No allowed origins specified, allowing all origins")
|
||||
self._logger.warning("No allowed origins specified, allowing all origins")
|
||||
return ["*"]
|
||||
|
||||
_logger.debug(f"Allowed origins: {origins}")
|
||||
self._logger.debug(f"Allowed origins: {origins}")
|
||||
return origins.split(",")
|
||||
|
||||
def with_database(self) -> Self:
|
||||
@@ -100,27 +104,64 @@ class WebApp(ApplicationABC):
|
||||
|
||||
return self
|
||||
|
||||
def with_routes(self, routes: list[Route]) -> Self:
|
||||
def with_routes(
|
||||
self,
|
||||
routes: list[ApiRoute],
|
||||
method: HTTPMethods,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self:
|
||||
self._check_for_app()
|
||||
assert self._routes is not None, "routes must not be None"
|
||||
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
|
||||
self._routes.extend(routes)
|
||||
assert all(isinstance(route, ApiRoute) for route in routes), "all routes must be of type ApiRoute"
|
||||
for route in routes:
|
||||
self.with_route(
|
||||
route.path,
|
||||
route.fn,
|
||||
method,
|
||||
authentication,
|
||||
roles,
|
||||
permissions,
|
||||
policies,
|
||||
match,
|
||||
)
|
||||
return self
|
||||
|
||||
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> Self:
|
||||
def with_route(
|
||||
self,
|
||||
path: str,
|
||||
fn: Callable[[Request], Any],
|
||||
method: HTTPMethods,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self:
|
||||
self._check_for_app()
|
||||
assert path is not None, "path must not be None"
|
||||
assert fn is not None, "fn must not be None"
|
||||
assert method in [
|
||||
"GET",
|
||||
"HEAD",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
"HEAD",
|
||||
], "method must be a valid HTTP method"
|
||||
self._routes.append(Route(path, fn, methods=[method], **kwargs))
|
||||
|
||||
Router.route(path, method, registry=self._routes)(fn)
|
||||
|
||||
if authentication:
|
||||
Router.authenticate()(fn)
|
||||
|
||||
if roles or permissions or policies:
|
||||
Router.authorize(roles, permissions, policies, match)(fn)
|
||||
|
||||
return self
|
||||
|
||||
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
||||
@@ -150,11 +191,11 @@ class WebApp(ApplicationABC):
|
||||
if isinstance(policy, dict):
|
||||
for name, resolver in policy.items():
|
||||
if not isinstance(name, str):
|
||||
_logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||
continue
|
||||
|
||||
if not callable(resolver):
|
||||
_logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||
continue
|
||||
|
||||
_policies.append(Policy(name, resolver))
|
||||
@@ -162,7 +203,7 @@ class WebApp(ApplicationABC):
|
||||
|
||||
_policies.append(policy)
|
||||
|
||||
self._policy_registry.extend_policies(_policies)
|
||||
self._policies.extend(_policies)
|
||||
|
||||
self.with_middleware(AuthorizationMiddleware)
|
||||
return self
|
||||
@@ -170,24 +211,16 @@ class WebApp(ApplicationABC):
|
||||
def _validate_policies(self):
|
||||
for rule in Router.get_authorization_rules():
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policy_registry.get(policy_name)
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
_logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||
|
||||
async def main(self):
|
||||
_logger.debug(f"Preparing API")
|
||||
self._logger.debug(f"Preparing API")
|
||||
self._validate_policies()
|
||||
|
||||
if self._app is None:
|
||||
routes = [
|
||||
Route(
|
||||
path=route.path,
|
||||
endpoint=self._services.inject(route.endpoint),
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
)
|
||||
for route in self._routes + Router.get_routes()
|
||||
]
|
||||
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
||||
|
||||
app = Starlette(
|
||||
routes=routes,
|
||||
@@ -205,7 +238,7 @@ class WebApp(ApplicationABC):
|
||||
else:
|
||||
app = self._app
|
||||
|
||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
|
||||
config = uvicorn.Config(
|
||||
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
|
||||
@@ -213,4 +246,4 @@ class WebApp(ApplicationABC):
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
_logger.info("Shutdown API")
|
||||
self._logger.info("Shutdown API")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class APILogger(Logger):
|
||||
class APILogger(WrappedLogger):
|
||||
|
||||
def __init__(self, source: str):
|
||||
Logger.__init__(self, source, "api")
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "api")
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .authentication import AuthenticationMiddleware
|
||||
from .authorization import AuthorizationMiddleware
|
||||
from .logging import LoggingMiddleware
|
||||
from .request import RequestMiddleware
|
||||
|
||||
@@ -2,8 +2,8 @@ from keycloak import KeycloakAuthenticationError
|
||||
from starlette.types import Scope, Receive, Send
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
@@ -11,15 +11,15 @@ from cpl.auth.schema import AuthUserDao, AuthUser
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._keycloak = keycloak
|
||||
self._user_dao = user_dao
|
||||
|
||||
@@ -28,11 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
_logger.trace(f"No authentication required for {url}")
|
||||
self._logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
|
||||
|
||||
auth_header = request.headers.get("Authorization", None)
|
||||
@@ -41,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
token = auth_header.split("Bearer ")[1]
|
||||
if not await self._verify_login(token):
|
||||
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
self._logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
|
||||
|
||||
# check user exists in db, if not create
|
||||
@@ -51,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
user = await self._get_or_crate_user(keycloak_id)
|
||||
if user.deleted:
|
||||
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||
|
||||
request.state.user = user
|
||||
@@ -73,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
token_info = self._keycloak.introspect(token)
|
||||
return token_info.get("active", False)
|
||||
except KeycloakAuthenticationError as e:
|
||||
_logger.debug(f"Keycloak authentication error: {e}")
|
||||
self._logger.debug(f"Keycloak authentication error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
_logger.error(f"Unexpected error during token verification: {e}")
|
||||
self._logger.error(f"Unexpected error during token verification: {e}")
|
||||
return False
|
||||
|
||||
@@ -11,22 +11,27 @@ from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._policies = policies
|
||||
self._user_dao = user_dao
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = get_request()
|
||||
user = get_user()
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_authorization_rules_paths():
|
||||
self._logger.trace(f"No authorization required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
user = get_user()
|
||||
if not user:
|
||||
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
|
||||
|
||||
@@ -48,17 +53,21 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
if rule["permissions"]:
|
||||
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
_logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||
self._logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||
continue
|
||||
|
||||
if not await policy.resolve(user):
|
||||
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
return await self._call_next(scope, receive, send)
|
||||
|
||||
@@ -6,15 +6,17 @@ from starlette.types import Receive, Scope, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class LoggingMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app):
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
await self._call_next(scope, receive, send)
|
||||
@@ -53,9 +55,8 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
}
|
||||
return {key: value for key, value in headers.items() if key in relevant_keys}
|
||||
|
||||
@classmethod
|
||||
async def _log_request(cls, request: Request):
|
||||
_logger.debug(
|
||||
async def _log_request(self, request: Request):
|
||||
self._logger.debug(
|
||||
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
)
|
||||
|
||||
@@ -64,7 +65,7 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
user = get_user()
|
||||
|
||||
request_info = {
|
||||
"headers": cls._filter_relevant_headers(dict(request.headers)),
|
||||
"headers": self._filter_relevant_headers(dict(request.headers)),
|
||||
"args": dict(request.query_params),
|
||||
"form-data": (
|
||||
await request.form()
|
||||
@@ -78,10 +79,9 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
),
|
||||
}
|
||||
|
||||
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
|
||||
@staticmethod
|
||||
async def _log_after_request(request: Request, status_code: int, duration: float):
|
||||
_logger.info(
|
||||
async def _log_after_request(self, request: Request, status_code: int, duration: float):
|
||||
self._logger.info(
|
||||
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
|
||||
@@ -9,16 +9,19 @@ from starlette.types import Scope, Receive, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app):
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._ctx_token = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
@@ -33,7 +36,7 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
async def set_request_data(self, request: TRequest):
|
||||
request.state.request_id = uuid4()
|
||||
request.state.start_time = time.time()
|
||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
self._logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
|
||||
self._ctx_token = _request_context.set(request)
|
||||
|
||||
@@ -45,7 +48,7 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
if self._ctx_token is None:
|
||||
return
|
||||
|
||||
_logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_request_context.reset(self._ctx_token)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .api_route import ApiRoute
|
||||
from .policy import Policy
|
||||
from .validation_match import ValidationMatch
|
||||
|
||||
43
src/cpl-api/cpl/api/model/api_route.py
Normal file
43
src/cpl-api/cpl/api/model/api_route.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Callable
|
||||
|
||||
from starlette.routing import Route
|
||||
|
||||
from cpl.api.typing import HTTPMethods
|
||||
|
||||
|
||||
class ApiRoute:
|
||||
|
||||
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
|
||||
self._path = path
|
||||
self._fn = fn
|
||||
self._method = method
|
||||
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._fn.__name__
|
||||
|
||||
@property
|
||||
def fn(self) -> Callable:
|
||||
return self._fn
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self._path
|
||||
|
||||
@property
|
||||
def method(self) -> HTTPMethods:
|
||||
return self._method
|
||||
|
||||
@property
|
||||
def kwargs(self) -> dict:
|
||||
return self._kwargs
|
||||
|
||||
def to_starlette(self, wrap_endpoint: Callable = None) -> Route:
|
||||
return Route(
|
||||
self._path,
|
||||
self._fn if not wrap_endpoint else wrap_endpoint(self._fn),
|
||||
methods=[self._method],
|
||||
**self._kwargs,
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Optional, Any, Coroutine, Awaitable
|
||||
from typing import Optional
|
||||
|
||||
from cpl.api.typing import PolicyResolver
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from .policy import PolicyRegistry
|
||||
from .route import RouteRegistry
|
||||
|
||||
@@ -1,23 +1,28 @@
|
||||
from typing import Optional
|
||||
|
||||
from cpl.api.model.policy import Policy
|
||||
from cpl.core.abc.registry_abc import RegistryABC
|
||||
|
||||
|
||||
class PolicyRegistry:
|
||||
class PolicyRegistry(RegistryABC):
|
||||
|
||||
def __init__(self):
|
||||
self._policies: dict[str, Policy] = {}
|
||||
RegistryABC.__init__(self)
|
||||
|
||||
def extend_policies(self, policies: list[Policy]):
|
||||
for policy in policies:
|
||||
self.add_policy(policy)
|
||||
def extend(self, items: list[Policy]):
|
||||
for policy in items:
|
||||
self.add(policy)
|
||||
|
||||
def add_policy(self, policy: Policy):
|
||||
assert isinstance(policy, Policy), "policy must be an instance of Policy"
|
||||
def add(self, item: Policy):
|
||||
assert isinstance(item, Policy), "policy must be an instance of Policy"
|
||||
|
||||
if policy.name in self._policies:
|
||||
raise ValueError(f"Policy {policy.name} is already registered")
|
||||
if item.name in self._items:
|
||||
raise ValueError(f"Policy {item.name} is already registered")
|
||||
|
||||
self._policies[policy.name] = policy
|
||||
self._items[item.name] = item
|
||||
|
||||
def get(self, name: str) -> Optional[Policy]:
|
||||
return self._policies.get(name)
|
||||
def get(self, key: str) -> Optional[Policy]:
|
||||
return self._items.get(key)
|
||||
|
||||
def all(self) -> list[Policy]:
|
||||
return list(self._items.values())
|
||||
|
||||
32
src/cpl-api/cpl/api/registry/route.py
Normal file
32
src/cpl-api/cpl/api/registry/route.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Optional
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.core.abc.registry_abc import RegistryABC
|
||||
|
||||
|
||||
class RouteRegistry(RegistryABC):
|
||||
|
||||
def __init__(self):
|
||||
RegistryABC.__init__(self)
|
||||
|
||||
def extend(self, items: list[ApiRoute]):
|
||||
for policy in items:
|
||||
self.add(policy)
|
||||
|
||||
def add(self, item: ApiRoute):
|
||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||
|
||||
if item.path in self._items:
|
||||
raise ValueError(f"ApiRoute {item.path} is already registered")
|
||||
|
||||
self._items[item.path] = item
|
||||
|
||||
def set(self, item: ApiRoute):
|
||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||
self._items[item.path] = item
|
||||
|
||||
def get(self, key: str) -> Optional[ApiRoute]:
|
||||
return self._items.get(key)
|
||||
|
||||
def all(self) -> list[ApiRoute]:
|
||||
return list(self._items.values())
|
||||
@@ -1,26 +1,25 @@
|
||||
from enum import Enum
|
||||
|
||||
from starlette.routing import Route
|
||||
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.typing import HTTPMethods
|
||||
|
||||
|
||||
class Router:
|
||||
_registered_routes: list[Route] = []
|
||||
_auth_required: list[str] = []
|
||||
_authorization_rules: list[dict] = []
|
||||
|
||||
@classmethod
|
||||
def get_routes(cls) -> list[Route]:
|
||||
return cls._registered_routes
|
||||
_authorization_rules: dict[str, dict] = {}
|
||||
|
||||
@classmethod
|
||||
def get_auth_required_routes(cls) -> list[str]:
|
||||
return cls._auth_required
|
||||
|
||||
@classmethod
|
||||
def get_authorization_rules_paths(cls) -> list[str]:
|
||||
return list(cls._authorization_rules.keys())
|
||||
|
||||
@classmethod
|
||||
def get_authorization_rules(cls) -> list[dict]:
|
||||
return cls._authorization_rules
|
||||
return list(cls._authorization_rules.values())
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls):
|
||||
@@ -42,7 +41,13 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
|
||||
def authorize(
|
||||
cls,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
):
|
||||
"""
|
||||
Decorator to mark a route as requiring authorization.
|
||||
Usage:
|
||||
@@ -67,52 +72,65 @@ class Router:
|
||||
permissions[permissions.index(perm)] = perm.value
|
||||
|
||||
def inner(fn):
|
||||
route_path = getattr(fn, "_route_path", None)
|
||||
if not route_path:
|
||||
path = getattr(fn, "_route_path", None)
|
||||
if not path:
|
||||
return fn
|
||||
|
||||
if route_path in cls._authorization_rules:
|
||||
raise ValueError(f"Route {route_path} is already registered for authorization")
|
||||
if path in cls._authorization_rules:
|
||||
raise ValueError(f"Route {path} is already registered for authorization")
|
||||
|
||||
cls._authorization_rules.append({
|
||||
cls._authorization_rules[path] = {
|
||||
"roles": roles or [],
|
||||
"permissions": permissions or [],
|
||||
"policies": policies or [],
|
||||
"match": match or ValidationMatch.all,
|
||||
})
|
||||
}
|
||||
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path=None, **kwargs):
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
if not registry:
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
def inner(fn):
|
||||
cls._registered_routes.append(Route(path, fn, **kwargs))
|
||||
routes.add(ApiRoute(path, fn, method, **kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def get(cls, path=None, **kwargs):
|
||||
return cls.route(path, methods=["GET"], **kwargs)
|
||||
def get(cls, path: str, **kwargs):
|
||||
return cls.route(path, "GET", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def post(cls, path=None, **kwargs):
|
||||
return cls.route(path, methods=["POST"], **kwargs)
|
||||
def head(cls, path: str, **kwargs):
|
||||
return cls.route(path, "HEAD", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def head(cls, path=None, **kwargs):
|
||||
return cls.route(path, methods=["HEAD"], **kwargs)
|
||||
def post(cls, path: str, **kwargs):
|
||||
return cls.route(path, "POST", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def put(cls, path=None, **kwargs):
|
||||
return cls.route(path, methods=["PUT"], **kwargs)
|
||||
def put(cls, path: str, **kwargs):
|
||||
return cls.route(path, "PUT", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, path=None, **kwargs):
|
||||
return cls.route(path, methods=["DELETE"], **kwargs)
|
||||
def patch(cls, path: str, **kwargs):
|
||||
return cls.route(path, "PATCH", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, path: str, **kwargs):
|
||||
return cls.route(path, "DELETE", **kwargs)
|
||||
|
||||
@classmethod
|
||||
def override(cls):
|
||||
@@ -125,13 +143,22 @@ class Router:
|
||||
...
|
||||
"""
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
|
||||
def inner(fn):
|
||||
route_path = getattr(fn, "_route_path", None)
|
||||
path = getattr(fn, "_route_path", None)
|
||||
if path is None:
|
||||
raise ValueError("Cannot override a route that has not been registered yet")
|
||||
|
||||
routes = list(filter(lambda x: x.path == route_path, cls._registered_routes))
|
||||
for route in routes[:-1]:
|
||||
cls._registered_routes.remove(route)
|
||||
route = routes.get(path)
|
||||
if route is None:
|
||||
raise ValueError(f"Cannot override a route that does not exist: {path}")
|
||||
|
||||
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@@ -9,11 +9,11 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.auth.schema import AuthUser
|
||||
|
||||
TRequest = Union[Request, WebSocket]
|
||||
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
||||
HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
||||
PartialMiddleware = Union[
|
||||
ASGIMiddleware,
|
||||
Type[ASGIMiddleware],
|
||||
Middleware,
|
||||
Callable[[ASGIApp], ASGIApp],
|
||||
]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
|
||||
@@ -2,9 +2,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import Callable, Self
|
||||
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.log import LogSettings
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.log_settings import LogSettings
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class ApplicationExtensionABC(ABC):
|
||||
|
||||
@@ -6,7 +6,7 @@ from cpl.auth import permission as _permission
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
from .auth_logger import AuthLogger
|
||||
from .logger import AuthLogger
|
||||
from .keycloak_settings import KeycloakSettings
|
||||
from .permission_seeder import PermissionSeeder
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from cpl.core.log import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class AuthLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "auth")
|
||||
@@ -1,15 +1,13 @@
|
||||
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
|
||||
_logger = AuthLogger("keycloak")
|
||||
from cpl.auth.logger import AuthLogger
|
||||
|
||||
|
||||
class KeycloakAdmin(_KeycloakAdmin):
|
||||
|
||||
def __init__(self, settings: KeycloakSettings):
|
||||
_logger.info("Initializing Keycloak admin")
|
||||
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||
# logger.info("Initializing Keycloak admin")
|
||||
_connection = KeycloakOpenIDConnection(
|
||||
server_url=settings.url,
|
||||
client_id=settings.client_id,
|
||||
|
||||
@@ -2,15 +2,13 @@ from typing import Optional
|
||||
|
||||
from keycloak import KeycloakOpenID
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
|
||||
_logger = AuthLogger("keycloak")
|
||||
|
||||
|
||||
class KeycloakClient(KeycloakOpenID):
|
||||
|
||||
def __init__(self, settings: KeycloakSettings):
|
||||
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||
KeycloakOpenID.__init__(
|
||||
self,
|
||||
server_url=settings.url,
|
||||
@@ -18,7 +16,7 @@ class KeycloakClient(KeycloakOpenID):
|
||||
realm_name=settings.realm,
|
||||
client_secret_key=settings.client_secret,
|
||||
)
|
||||
_logger.info("Initializing Keycloak client")
|
||||
logger.info("Initializing Keycloak client")
|
||||
|
||||
def get_user_id(self, token: str) -> Optional[str]:
|
||||
info = self.introspect(token)
|
||||
|
||||
7
src/cpl-auth/cpl/auth/logger.py
Normal file
7
src/cpl-auth/cpl/auth/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class AuthLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "auth")
|
||||
@@ -14,14 +14,13 @@ from cpl.auth.schema import (
|
||||
)
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
from cpl.database.logger import DBLogger
|
||||
|
||||
|
||||
class PermissionSeeder(DataSeederABC):
|
||||
def __init__(
|
||||
self,
|
||||
logger: DBLogger,
|
||||
permission_dao: PermissionDao,
|
||||
role_dao: RoleDao,
|
||||
role_permission_dao: RolePermissionDao,
|
||||
@@ -29,6 +28,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
):
|
||||
DataSeederABC.__init__(self)
|
||||
self._logger = logger
|
||||
self._permission_dao = permission_dao
|
||||
self._role_dao = role_dao
|
||||
self._role_permission_dao = role_permission_dao
|
||||
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
possible_permissions = [permission for permission in PermissionsRegistry.get()]
|
||||
|
||||
if len(permissions) == len(possible_permissions):
|
||||
_logger.info("Permissions already existing")
|
||||
self._logger.info("Permissions already existing")
|
||||
await self._update_missing_descriptions()
|
||||
return
|
||||
|
||||
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
|
||||
await self._permission_dao.delete_many(to_delete, hard_delete=True)
|
||||
|
||||
_logger.warning("Permissions incomplete")
|
||||
self._logger.warning("Permissions incomplete")
|
||||
permission_names = [permission.name for permission in permissions]
|
||||
await self._permission_dao.create_many(
|
||||
[
|
||||
|
||||
@@ -3,15 +3,12 @@ from typing import Optional
|
||||
from cpl.auth.schema._administration.api_key import ApiKey
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ApiKeyDao(DbModelDaoABC[ApiKey]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
|
||||
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys"))
|
||||
|
||||
self.attribute(ApiKey.identifier, str)
|
||||
self.attribute(ApiKey.key, str, "keystring")
|
||||
|
||||
@@ -6,14 +6,12 @@ from async_property import async_property
|
||||
from keycloak import KeycloakGetError
|
||||
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = AuthLogger(__name__)
|
||||
|
||||
|
||||
class AuthUser(DbModelABC):
|
||||
def __init__(
|
||||
@@ -38,12 +36,13 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user(self._keycloak_id).get("username")
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("username")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@property
|
||||
@@ -52,12 +51,13 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user(self._keycloak_id).get("email")
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("email")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@async_property
|
||||
|
||||
@@ -4,17 +4,14 @@ from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
||||
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users"))
|
||||
|
||||
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
|
||||
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||
|
||||
self.attribute(ApiKeyPermission.api_key_id, int)
|
||||
self.attribute(ApiKeyPermission.permission_id, int)
|
||||
|
||||
@@ -3,15 +3,12 @@ from typing import Optional
|
||||
from cpl.auth.schema._permission.permission import Permission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class PermissionDao(DbModelDaoABC[Permission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
|
||||
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions"))
|
||||
|
||||
self.attribute(Permission.name, str)
|
||||
self.attribute(Permission.description, Optional[str])
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from cpl.auth.schema._permission.role import Role
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RoleDao(DbModelDaoABC[Role]):
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
|
||||
DbModelDaoABC.__init__(self, Role, TableManager.get("roles"))
|
||||
self.attribute(Role.name, str)
|
||||
self.attribute(Role.description, str)
|
||||
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from cpl.auth.schema._permission.role_permission import RolePermission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RolePermissionDao(DbModelDaoABC[RolePermission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
|
||||
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions"))
|
||||
|
||||
self.attribute(RolePermission.role_id, int)
|
||||
self.attribute(RolePermission.permission_id, int)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from cpl.auth.schema._permission.role_user import RoleUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RoleUserDao(DbModelDaoABC[RoleUser]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
|
||||
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users"))
|
||||
|
||||
self.attribute(RoleUser.role_id, int)
|
||||
self.attribute(RoleUser.user_id, int)
|
||||
|
||||
0
src/cpl-core/cpl/core/abc/__init__.py
Normal file
0
src/cpl-core/cpl/core/abc/__init__.py
Normal file
23
src/cpl-core/cpl/core/abc/registry_abc.py
Normal file
23
src/cpl-core/cpl/core/abc/registry_abc.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
|
||||
|
||||
class RegistryABC(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
self._items: dict[str, T] = {}
|
||||
|
||||
@abstractmethod
|
||||
def extend(self, items: list[T]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def add(self, item: T) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> T | None: ...
|
||||
|
||||
@abstractmethod
|
||||
def all(self) -> list[T]: ...
|
||||
@@ -1,17 +1,18 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
|
||||
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
||||
|
||||
_logger = AuthLogger(__name__)
|
||||
|
||||
def set_user(user: Optional[AuthUser]):
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
|
||||
def set_user(user_id: Optional[AuthUser]):
|
||||
_logger.trace("Setting user context", user_id)
|
||||
_user_context.set(user_id)
|
||||
logger = ServiceProviderABC.get_global_service(LoggerABC)
|
||||
logger.trace("Setting user context", user.id)
|
||||
_user_context.set(user)
|
||||
|
||||
|
||||
def get_user() -> Optional[AuthUser]:
|
||||
|
||||
@@ -2,3 +2,4 @@ from .logger import Logger
|
||||
from .logger_abc import LoggerABC
|
||||
from .log_level import LogLevel
|
||||
from .log_settings import LogSettings
|
||||
from .structured_logger import StructuredLogger
|
||||
|
||||
111
src/cpl-core/cpl/core/log/structured_logger.py
Normal file
111
src/cpl-core/cpl/core/log/structured_logger.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source, Messages
|
||||
|
||||
|
||||
class StructuredLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source, file_prefix: str = None):
|
||||
Logger.__init__(self, source, file_prefix)
|
||||
|
||||
@property
|
||||
def log_file(self):
|
||||
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
|
||||
|
||||
def _log(self, level: LogLevel, *messages: Messages):
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
formatted_message = self._format_message(level.value, timestamp, *messages)
|
||||
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
|
||||
|
||||
self._write_log_to_file(level, structured_message)
|
||||
self._write_to_console(level, formatted_message)
|
||||
except Exception as e:
|
||||
print(f"Error while logging: {e} -> {traceback.format_exc()}")
|
||||
|
||||
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
|
||||
structured_message = {
|
||||
"timestamp": timestamp,
|
||||
"level": level.upper(),
|
||||
"source": self._source,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
self._enrich_message_with_request(structured_message)
|
||||
self._enrich_message_with_user(structured_message)
|
||||
|
||||
return json.dumps(structured_message, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
|
||||
scope = dict(request.scope)
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [convert(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): convert(v) for k, v in value.items()}
|
||||
if not isinstance(value, (str, int, float, bool, type(None))):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
|
||||
|
||||
if not include_headers and "headers" in serializable_scope:
|
||||
serializable_scope["headers"] = "<omitted>"
|
||||
|
||||
return serializable_scope
|
||||
|
||||
def _enrich_message_with_request(self, message: dict):
|
||||
if importlib.util.find_spec("cpl.api") is None:
|
||||
return
|
||||
|
||||
from cpl.api.middleware.request import get_request
|
||||
from starlette.requests import Request
|
||||
|
||||
request = get_request()
|
||||
|
||||
if request is None:
|
||||
return
|
||||
|
||||
message["request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"scope": self._scope_to_json(request),
|
||||
}
|
||||
if isinstance(request, Request) and request.scope == "http":
|
||||
request: Request = request # fix typing for IDEs
|
||||
|
||||
message["request"]["data"] = asyncio.create_task(request.body())
|
||||
|
||||
@staticmethod
|
||||
def _enrich_message_with_user(message: dict):
|
||||
if importlib.util.find_spec("cpl-auth") is None:
|
||||
return
|
||||
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
user = get_user()
|
||||
if user is None:
|
||||
return
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
|
||||
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
kc_user = keycloak.get_user(user.keycloak_id)
|
||||
message["user"] = {
|
||||
"id": str(user.id),
|
||||
"username": kc_user.get("username"),
|
||||
"email": kc_user.get("email"),
|
||||
}
|
||||
100
src/cpl-core/cpl/core/log/wrapped_logger.py
Normal file
100
src/cpl-core/cpl/core/log/wrapped_logger.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import inspect
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.log import LoggerABC, LogLevel
|
||||
from cpl.core.typing import Messages
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class WrappedLogger(LoggerABC):
|
||||
|
||||
def __init__(self, file_prefix: str):
|
||||
LoggerABC.__init__(self)
|
||||
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
|
||||
|
||||
self._source = None
|
||||
self._file_prefix = file_prefix
|
||||
|
||||
self._set_logger()
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def _set_logger(self, services: ServiceProviderABC):
|
||||
from cpl.core.log import Logger
|
||||
|
||||
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
|
||||
if t_logger is None:
|
||||
raise Exception("No LoggerABC service registered in ServiceProviderABC")
|
||||
|
||||
self._logger = t_logger(self._source, self._file_prefix)
|
||||
|
||||
def set_level(self, level: LogLevel):
|
||||
self._logger.set_level(level)
|
||||
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._format_message(level, timestamp, *messages)
|
||||
|
||||
@staticmethod
|
||||
def _get_source() -> str | None:
|
||||
stack = inspect.stack()
|
||||
if len(stack) <= 1:
|
||||
return None
|
||||
|
||||
from cpl.dependency import ServiceCollection
|
||||
|
||||
ignore_classes = [
|
||||
ServiceProviderABC,
|
||||
ServiceProviderABC.__subclasses__(),
|
||||
ServiceCollection,
|
||||
WrappedLogger,
|
||||
WrappedLogger.__subclasses__(),
|
||||
]
|
||||
|
||||
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
|
||||
|
||||
for i, frame_info in enumerate(stack[1:]):
|
||||
module = inspect.getmodule(frame_info.frame)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
if module.__name__ in ignore_classes or module in ignore_classes:
|
||||
continue
|
||||
|
||||
if module in ignore_modules or module.__name__ in ignore_modules:
|
||||
continue
|
||||
|
||||
if module.__name__ != __name__:
|
||||
return module.__name__
|
||||
|
||||
return None
|
||||
|
||||
def _set_source(self):
|
||||
self._source = self._get_source()
|
||||
self._set_logger()
|
||||
|
||||
def header(self, string: str):
|
||||
self._set_source()
|
||||
self._logger.header(string)
|
||||
|
||||
def trace(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.trace(*messages)
|
||||
|
||||
def debug(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.debug(*messages)
|
||||
|
||||
def info(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.info(*messages)
|
||||
|
||||
def warning(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.warning(*messages)
|
||||
|
||||
def error(self, messages: str, e: Exception = None):
|
||||
self._set_source()
|
||||
self._logger.error(messages, e)
|
||||
|
||||
def fatal(self, messages: str, e: Exception = None):
|
||||
self._set_source()
|
||||
self._logger.fatal(messages, e)
|
||||
@@ -9,7 +9,7 @@ from cpl.core.utils.get_value import get_value
|
||||
from cpl.core.utils.string import String
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.const import DATETIME_FORMAT
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||
@@ -18,16 +18,12 @@ from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSor
|
||||
class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
|
||||
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
self._db = ServiceProviderABC.get_global_service(DBContextABC)
|
||||
|
||||
self._logger = DBLogger(source)
|
||||
self._model_type = model_type
|
||||
self._table_name = table_name
|
||||
|
||||
self._logger = DBLogger(source)
|
||||
self._logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
self._model_type = model_type
|
||||
self._table_name = table_name
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
|
||||
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
|
||||
DataAccessObjectABC.__init__(self, source, model_type, table_name)
|
||||
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||
DataAccessObjectABC.__init__(self, model_type, table_name)
|
||||
|
||||
self.attribute(DbModelABC.id, int, ignore=True)
|
||||
self.attribute(DbModelABC.deleted, bool)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from cpl.core.log import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class DBLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "db")
|
||||
7
src/cpl-database/cpl/database/logger.py
Normal file
7
src/cpl-database/cpl/database/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class DBLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "db")
|
||||
@@ -4,18 +4,17 @@ from typing import Any, List, Dict, Tuple, Union
|
||||
from mysql.connector import Error as MySQLError, PoolError
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model.database_settings import DatabaseSettings
|
||||
from cpl.database.mysql.mysql_pool import MySQLPool
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class DBContext(DBContextABC):
|
||||
def __init__(self):
|
||||
def __init__(self, logger: DBLogger):
|
||||
DBContextABC.__init__(self)
|
||||
self._logger = logger
|
||||
|
||||
self._pool: MySQLPool = None
|
||||
self._fails = 0
|
||||
|
||||
@@ -23,62 +22,62 @@ class DBContext(DBContextABC):
|
||||
|
||||
def connect(self, database_settings: DatabaseSettings):
|
||||
try:
|
||||
_logger.debug("Connecting to database")
|
||||
self._logger.debug("Connecting to database")
|
||||
self._pool = MySQLPool(
|
||||
database_settings,
|
||||
)
|
||||
_logger.info("Connected to database")
|
||||
self._logger.info("Connected to database")
|
||||
except Exception as e:
|
||||
_logger.fatal("Connecting to database failed", e)
|
||||
self._logger.fatal("Connecting to database failed", e)
|
||||
|
||||
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
|
||||
_logger.trace(f"execute {statement} with args: {args}")
|
||||
self._logger.trace(f"execute {statement} with args: {args}")
|
||||
return await self._pool.execute(statement, args, multi)
|
||||
|
||||
async def select_map(self, statement: str, args=None) -> List[Dict]:
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select_map(statement, args)
|
||||
except (MySQLError, PoolError) as e:
|
||||
if self._fails >= 3:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
_logger.debug("Retry select")
|
||||
self._logger.debug("Retry select")
|
||||
return await self.select_map(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select(statement, args)
|
||||
except (MySQLError, PoolError) as e:
|
||||
if self._fails >= 3:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
_logger.debug("Retry select")
|
||||
self._logger.debug("Retry select")
|
||||
return await self.select(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
@@ -4,10 +4,9 @@ import sqlparse
|
||||
from mysql.connector.aio import MySQLConnectionPool
|
||||
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class MySQLPool:
|
||||
@@ -36,7 +35,8 @@ class MySQLPool:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
except Exception as e:
|
||||
_logger.fatal(f"Error connecting to the database: {e}")
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger.fatal(f"Error connecting to the database: {e}")
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
|
||||
@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.database_settings import DatabaseSettings
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.database.postgres.postgres_pool import PostgresPool
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class DBContext(DBContextABC):
|
||||
def __init__(self):
|
||||
def __init__(self, logger: DBLogger):
|
||||
DBContextABC.__init__(self)
|
||||
|
||||
self._logger = logger
|
||||
self._pool: PostgresPool = None
|
||||
self._fails = 0
|
||||
|
||||
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
|
||||
|
||||
def connect(self, database_settings: DatabaseSettings):
|
||||
try:
|
||||
_logger.debug("Connecting to database")
|
||||
self._logger.debug("Connecting to database")
|
||||
self._pool = PostgresPool(
|
||||
database_settings,
|
||||
Environment.get("DB_POOL_SIZE", int, 1),
|
||||
)
|
||||
_logger.info("Connected to database")
|
||||
self._logger.info("Connected to database")
|
||||
except Exception as e:
|
||||
_logger.fatal("Connecting to database failed", e)
|
||||
self._logger.fatal("Connecting to database failed", e)
|
||||
|
||||
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
|
||||
_logger.trace(f"execute {statement} with args: {args}")
|
||||
self._logger.trace(f"execute {statement} with args: {args}")
|
||||
return await self._pool.execute(statement, args, multi)
|
||||
|
||||
async def select_map(self, statement: str, args=None) -> list[dict]:
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select_map(statement, args)
|
||||
except (OperationalError, PoolTimeout) as e:
|
||||
if self._fails >= 3:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
_logger.debug("Retry select")
|
||||
self._logger.debug("Retry select")
|
||||
return await self.select_map(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select(statement, args)
|
||||
except (OperationalError, PoolTimeout) as e:
|
||||
if self._fails >= 3:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
_logger.debug("Retry select")
|
||||
self._logger.debug("Retry select")
|
||||
return await self.select(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
@@ -5,10 +5,9 @@ from psycopg import sql
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class PostgresPool:
|
||||
@@ -38,7 +37,8 @@ class PostgresPool:
|
||||
await pool.check_connection(con)
|
||||
except PoolTimeout as e:
|
||||
await pool.close()
|
||||
_logger.fatal(f"Failed to connect to the database", e)
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger.fatal(f"Failed to connect to the database", e)
|
||||
self._pool = pool
|
||||
|
||||
return self._pool
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
|
||||
|
||||
def __init__(self):
|
||||
DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
|
||||
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations"))
|
||||
|
||||
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")
|
||||
|
||||
@@ -2,18 +2,17 @@ import glob
|
||||
import os
|
||||
|
||||
from cpl.database.abc import DBContextABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import Migration
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class MigrationService:
|
||||
|
||||
def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
||||
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
||||
self._logger = logger
|
||||
self._db = db
|
||||
self._executedMigrationDao = executedMigrationDao
|
||||
|
||||
@@ -96,13 +95,13 @@ class MigrationService:
|
||||
if migration_from_db is not None:
|
||||
continue
|
||||
|
||||
_logger.debug(f"Running upgrade migration: {migration.name}")
|
||||
self._logger.debug(f"Running upgrade migration: {migration.name}")
|
||||
|
||||
await self._db.execute(migration.script, multi=True)
|
||||
|
||||
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
|
||||
except Exception as e:
|
||||
_logger.fatal(
|
||||
self._logger.fatal(
|
||||
f"Migration failed: {migration.name}\n{active_statement}",
|
||||
e,
|
||||
)
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class SeederService:
|
||||
|
||||
def __init__(self, provider: ServiceProviderABC):
|
||||
self._provider = provider
|
||||
self._logger = provider.get_service(DBLogger)
|
||||
|
||||
async def seed(self):
|
||||
seeders = self._provider.get_services(DataSeederABC)
|
||||
_logger.debug(f"Found {len(seeders)} seeders")
|
||||
self._logger.debug(f"Found {len(seeders)} seeders")
|
||||
for seeder in seeders:
|
||||
await seeder.seed()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Union, Type, Callable, Self
|
||||
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.core.typing import T, Service
|
||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||
@@ -80,5 +79,20 @@ class ServiceCollection:
|
||||
return self
|
||||
|
||||
def add_logging(self) -> Self:
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
self.add_transient(LoggerABC, Logger)
|
||||
for wrapper in WrappedLogger.__subclasses__():
|
||||
self.add_transient(wrapper)
|
||||
return self
|
||||
|
||||
def add_structured_logging(self) -> Self:
|
||||
from cpl.core.log.structured_logger import StructuredLogger
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
self.add_transient(LoggerABC, StructuredLogger)
|
||||
|
||||
for wrapper in WrappedLogger.__subclasses__():
|
||||
self.add_transient(wrapper)
|
||||
return self
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import typing
|
||||
from inspect import signature, Parameter, Signature
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
|
||||
@@ -158,6 +158,13 @@ class ServiceProvider(ServiceProviderABC):
|
||||
|
||||
return implementation
|
||||
|
||||
|
||||
def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]:
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
|
||||
return descriptor.service_type
|
||||
return None
|
||||
|
||||
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
|
||||
implementations = []
|
||||
|
||||
@@ -167,3 +174,10 @@ class ServiceProvider(ServiceProviderABC):
|
||||
implementations.extend(self._get_services(service_type))
|
||||
|
||||
return implementations
|
||||
|
||||
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
|
||||
types = []
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
|
||||
types.append(descriptor.service_type)
|
||||
return types
|
||||
|
||||
@@ -85,6 +85,20 @@ class ServiceProviderABC(ABC):
|
||||
Object of type Optional[:class:`cpl.core.type.T`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_service_type(self,instance_type: Type[T]) -> Optional[Type[T]]:
|
||||
r"""Returns the registered service type for loggers
|
||||
|
||||
Parameter
|
||||
---------
|
||||
instance_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type Optional[:class:`type`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||
r"""Returns instance of given type
|
||||
@@ -99,6 +113,20 @@ class ServiceProviderABC(ABC):
|
||||
Object of type list[Optional[:class:`cpl.core.type.T`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
|
||||
r"""Returns all registered service types
|
||||
|
||||
Parameter
|
||||
---------
|
||||
service_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type list[:class:`type`]
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def inject(cls, f=None):
|
||||
r"""Decorator to allow injection into static and class methods
|
||||
|
||||
@@ -3,7 +3,7 @@ from .abc.email_client_abc import EMailClientABC
|
||||
from .email_client import EMailClient
|
||||
from .email_client_settings import EMailClientSettings
|
||||
from .email_model import EMail
|
||||
from .mail_logger import MailLogger
|
||||
from .logger import MailLogger
|
||||
|
||||
|
||||
def add_mail(collection: _ServiceCollection):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from cpl.mail.abc.email_client_abc import EMailClientABC
|
||||
from cpl.mail.email_client_settings import EMailClientSettings
|
||||
from cpl.mail.email_model import EMail
|
||||
from cpl.mail.mail_logger import MailLogger
|
||||
from cpl.mail.logger import MailLogger
|
||||
|
||||
|
||||
class EMailClient(EMailClientABC):
|
||||
|
||||
7
src/cpl-mail/cpl/mail/logger.py
Normal file
7
src/cpl-mail/cpl/mail/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class MailLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "mail")
|
||||
@@ -1,8 +0,0 @@
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class MailLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "mail")
|
||||
@@ -3,6 +3,7 @@ from starlette.responses import JSONResponse
|
||||
from cpl import api
|
||||
from cpl.api.application.web_app import WebApp
|
||||
from cpl.application import ApplicationBuilder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from service import PingService
|
||||
@@ -15,7 +16,8 @@ def main():
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
||||
|
||||
builder.services.add_logging()
|
||||
# builder.services.add_logging()
|
||||
builder.services.add_structured_logging()
|
||||
builder.services.add_transient(PingService)
|
||||
builder.services.add_module(api)
|
||||
|
||||
@@ -26,7 +28,7 @@ def main():
|
||||
app.with_authentication()
|
||||
app.with_authorization()
|
||||
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
|
||||
app.with_routes_directory("routes")
|
||||
|
||||
app.run()
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
from urllib.request import Request
|
||||
|
||||
from service import PingService
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.api import APILogger
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.log import Logger
|
||||
from service import PingService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
@Router.authorize(permissions=[Permissions.administrator])
|
||||
# @Router.authorize(permissions=[Permissions.administrator])
|
||||
# @Router.authorize(policies=["test"])
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: Logger):
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger):
|
||||
logger.info(f"Ping: {ping}")
|
||||
return JSONResponse(ping.ping(r))
|
||||
|
||||
@@ -5,7 +5,7 @@ from model.city import City
|
||||
class CityDao(DbModelDaoABC[City]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, City, "city")
|
||||
DbModelDaoABC.__init__(self, City, "city")
|
||||
|
||||
self.attribute(City.name, str)
|
||||
self.attribute(City.zip, int)
|
||||
|
||||
@@ -5,7 +5,7 @@ from model.user import User
|
||||
class UserDao(DbModelDaoABC[User]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, User, "users")
|
||||
DbModelDaoABC.__init__(self, User, "users")
|
||||
|
||||
self.attribute(User.name, str)
|
||||
self.attribute(User.city_id, int, db_name="CityId")
|
||||
|
||||
Reference in New Issue
Block a user