Compare commits

..

1 Commits

Author SHA1 Message Date
69bbbc8cee Authorization via with_route
Some checks failed
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 20s
Build on push / core (push) Successful in 20s
Build on push / dependency (push) Successful in 17s
Build on push / mail (push) Successful in 15s
Build on push / application (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 25s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 14s
Test before pr merge / test-lint (pull_request) Failing after 6s
2025-09-22 22:03:42 +02:00
213 changed files with 1974 additions and 2156 deletions

View File

@@ -25,11 +25,7 @@ jobs:
git tag git tag
DATE=$(date +'%Y.%m.%d') DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l) TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
if [ "$TAG_COUNT" -eq 0 ]; then BUILD_NUMBER=$(($TAG_COUNT + 1))
BUILD_NUMBER=0
else
BUILD_NUMBER=$(($TAG_COUNT + 1))
fi
VERSION_SUFFIX=${{ inputs.version_suffix }} VERSION_SUFFIX=${{ inputs.version_suffix }}
if [ -n "$VERSION_SUFFIX" ] && [ "$VERSION_SUFFIX" = "dev" ]; then if [ -n "$VERSION_SUFFIX" ] && [ "$VERSION_SUFFIX" = "dev" ]; then

3
.gitignore vendored
View File

@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
# cpl unittest stuff # cpl unittest stuff
unittests/test_*_playground unittests/test_*_playground
# cpl logs
**/logs/*.jsonl

View File

@@ -1,80 +0,0 @@
from starlette.responses import JSONResponse
from cpl import api
from cpl.api.application.web_app import WebApp
from cpl.api_module import ApiModule
from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule
from scoped_service import ScopedService
from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging()
builder.services.add_structured_logging()
builder.services.add_transient(PingService)
builder.services.add_module(MySQLModule)
builder.services.add_module(ApiModule)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
app = builder.build()
app.with_logging()
app.with_database()
app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
app.with_routes_directory("routes")
provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache:
raise Exception("Cache service is not working")
s1 = provider.get_service(ScopedService)
s2 = provider.get_service(ScopedService)
if s1.name == s2.name:
raise Exception("Scoped service is not working")
with provider.create_scope() as scope:
s3 = scope.get_service(ScopedService)
s4 = scope.get_service(ScopedService)
if s3.name != s4.name:
raise Exception("Scoped service is not working")
if s1.name == s3.name:
raise Exception("Scoped service is not working")
Console.write_line(
s1.name,
s2.name,
s3.name,
s4.name,
)
app.run()
if __name__ == "__main__":
main()

View File

@@ -1,21 +0,0 @@
from urllib.request import Request
from service import PingService
from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
from cpl.core.console import Console
from cpl.dependency import ServiceProvider
from scoped_service import ScopedService
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
logger.info(f"Ping: {ping}")
Console.write_line(scoped.name)
return JSONResponse(ping.ping(r))

View File

@@ -1,14 +0,0 @@
from cpl.core.console.console import Console
from cpl.core.utils.string import String
class ScopedService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -1,45 +0,0 @@
from cpl.application.abc import ApplicationABC
from cpl.core.console.console import Console
from cpl.dependency import ServiceProvider
from test_abc import TestABC
from test_service import TestService
from di_tester_service import DITesterService
from tester import Tester
class Application(ApplicationABC):
def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services)
def _part_of_scoped(self):
ts: TestService = self._services.get_service(TestService)
ts.run()
def main(self):
with self._services.create_scope() as scope:
Console.write_line("Scope1")
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
with self._services.create_scope() as scope:
Console.write_line("Scope2")
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
Console.write_line("Global")
self._part_of_scoped()
#from static_test import StaticTest
#StaticTest.test()
self._services.get_service(Tester)
Console.write_line(self._services.get_services(TestABC))

View File

@@ -1,27 +0,0 @@
from cpl.application.abc import StartupABC
from cpl.dependency import ServiceProvider, ServiceCollection
from di_tester_service import DITesterService
from test1_service import Test1Service
from test2_service import Test2Service
from test_abc import TestABC
from test_service import TestService
from tester import Tester
class Startup(StartupABC):
def __init__(self):
StartupABC.__init__(self)
@staticmethod
def configure_configuration(): ...
@staticmethod
def configure_services(services: ServiceCollection) -> ServiceProvider:
services.add_scoped(TestService)
services.add_scoped(DITesterService)
services.add_singleton(TestABC, Test1Service)
services.add_singleton(TestABC, Test2Service)
services.add_singleton(Tester)
return services.build()

View File

@@ -1,10 +0,0 @@
from cpl.dependency import ServiceProvider, ServiceProvider
from cpl.dependency.inject import inject
from test_service import TestService
class StaticTest:
@staticmethod
@inject
def test(services: ServiceProvider, t1: TestService):
t1.run()

View File

@@ -1,7 +0,0 @@
from cpl.core.console.console import Console
from test_abc import TestABC
class Tester:
def __init__(self, t1: TestABC, t2: TestABC, t3: TestABC, t: list[TestABC]):
Console.write_line("Tester:", t, t1, t2, t3)

View File

@@ -1,20 +0,0 @@
import asyncio
import time
from cpl.core.console import Console
from cpl.dependency.hosted.hosted_service import HostedService
class Hosted(HostedService):
def __init__(self):
self._stopped = False
async def start(self):
Console.write_line("Hosted Service Started")
while not self._stopped:
Console.write_line("Hosted Service Running")
await asyncio.sleep(5)
async def stop(self):
Console.write_line("Hosted Service Stopped")
self._stopped = True

View File

@@ -1,10 +0,0 @@
from cpl.core.console import Console
class ScopedService:
def __init__(self):
self.value = "I am a scoped service"
Console.write_line(self.value, self)
def get_value(self):
return self.value

View File

@@ -1,60 +0,0 @@
from cpl.core.console import Console
from cpl.core.utils.benchmark import Benchmark
from cpl.query.enumerable import Enumerable
from cpl.query.immutable_list import ImmutableList
from cpl.query.list import List
from cpl.query.set import Set
def _default():
Console.write_line(Enumerable.empty().to_list())
Console.write_line(Enumerable.range(0, 100).length)
Console.write_line(Enumerable.range(0, 100).to_list())
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
Console.write_line(
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
)
Console.write_line(List)
s =Enumerable.range(0, 10).to_set()
Console.write_line(s)
s.add(1)
Console.write_line(s)
data = Enumerable(
[
{"name": "Alice", "age": 30},
{"name": "Dave", "age": 35},
{"name": "Charlie", "age": 25},
{"name": "Bob", "age": 25},
]
)
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
def t_benchmark(data: list):
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all(
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
)
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
def main():
N = 1_000_000
data = list(range(N))
t_benchmark(data)
Console.write_line()
_default()
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,32 @@
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .logger import APILogger
from .settings import ApiSettings
def add_api(collection: _ServiceCollection):
try:
from cpl.database import mysql
collection.add_module(mysql)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-database", e)
try:
from cpl import auth
from cpl.auth import permission
collection.add_module(auth)
collection.add_module(permission)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)
_ServiceCollection.with_module(add_api, __name__)

View File

@@ -1 +0,0 @@
from .asgi_middleware_abc import ASGIMiddleware

View File

@@ -1 +0,0 @@
from .web_app import WebApp

View File

@@ -25,46 +25,42 @@ from cpl.api.registry.route import RouteRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.api.settings import ApiSettings from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.api_module import ApiModule
from cpl.application.abc.application_abc import ApplicationABC from cpl.application.abc.application_abc import ApplicationABC
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_module import PermissionsModule
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.dependency.inject import inject from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.service_provider import ServiceProvider
_logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy] PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC): class WebApp(ApplicationABC):
def __init__(self, services: ServiceProvider): def __init__(self, services: ServiceProviderABC):
super().__init__(services, [AuthModule, PermissionsModule, ApiModule]) super().__init__(services, [auth, api])
self._app: Starlette | None = None self._app: Starlette | None = None
self._logger = services.get_service(APILogger)
self._api_settings = Configuration.get(ApiSettings) self._api_settings = Configuration.get(ApiSettings)
self._policies = services.get_service(PolicyRegistry) self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry) self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = [] self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = { self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception, Exception: self._handle_exception,
APIError: self._handle_exception, APIError: self._handle_exception,
} }
self.with_middleware(RequestMiddleware) @staticmethod
self.with_middleware(LoggingMiddleware) async def _handle_exception(request: Request, exc: Exception):
async def _handle_exception(self, request: Request, exc: Exception):
if isinstance(exc, APIError): if isinstance(exc, APIError):
self._logger.error(exc) _logger.error(exc)
return JSONResponse({"error": str(exc)}, status_code=exc.status_code) return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
if hasattr(request.state, "request_id"): if hasattr(request.state, "request_id"):
self._logger.error(f"Request {request.state.request_id}", exc) _logger.error(f"Request {request.state.request_id}", exc)
else: else:
self._logger.error("Request unknown", exc) _logger.error("Request unknown", exc)
return JSONResponse({"error": str(exc)}, status_code=500) return JSONResponse({"error": str(exc)}, status_code=500)
@@ -72,10 +68,10 @@ class WebApp(ApplicationABC):
origins = self._api_settings.allowed_origins origins = self._api_settings.allowed_origins
if origins is None or origins == "": if origins is None or origins == "":
self._logger.warning("No allowed origins specified, allowing all origins") _logger.warning("No allowed origins specified, allowing all origins")
return ["*"] return ["*"]
self._logger.debug(f"Allowed origins: {origins}") _logger.debug(f"Allowed origins: {origins}")
return origins.split(",") return origins.split(",")
def with_database(self) -> Self: def with_database(self) -> Self:
@@ -171,9 +167,9 @@ class WebApp(ApplicationABC):
self._check_for_app() self._check_for_app()
if isinstance(middleware, Middleware): if isinstance(middleware, Middleware):
self._middleware.append(inject(middleware)) self._middleware.append(middleware)
elif callable(middleware): elif callable(middleware):
self._middleware.append(Middleware(inject(middleware))) self._middleware.append(Middleware(middleware))
else: else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable") raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
@@ -194,11 +190,11 @@ class WebApp(ApplicationABC):
if isinstance(policy, dict): if isinstance(policy, dict):
for name, resolver in policy.items(): for name, resolver in policy.items():
if not isinstance(name, str): if not isinstance(name, str):
self._logger.warning(f"Skipping policy at index {i}, name must be a string") _logger.warning(f"Skipping policy at index {i}, name must be a string")
continue continue
if not callable(resolver): if not callable(resolver):
self._logger.warning(f"Skipping policy {name}, resolver must be callable") _logger.warning(f"Skipping policy {name}, resolver must be callable")
continue continue
_policies.append(Policy(name, resolver)) _policies.append(Policy(name, resolver))
@@ -206,7 +202,7 @@ class WebApp(ApplicationABC):
_policies.append(policy) _policies.append(policy)
self._policies.extend(_policies) self._policies.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware) self.with_middleware(AuthorizationMiddleware)
return self return self
@@ -216,14 +212,14 @@ class WebApp(ApplicationABC):
for policy_name in rule["policies"]: for policy_name in rule["policies"]:
policy = self._policies.get(policy_name) policy = self._policies.get(policy_name)
if not policy: if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found") _logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self): async def main(self):
self._logger.debug(f"Preparing API") _logger.debug(f"Preparing API")
self._validate_policies() self._validate_policies()
if self._app is None: if self._app is None:
routes = [route.to_starlette(inject) for route in self._routes.all()] routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
app = Starlette( app = Starlette(
routes=routes, routes=routes,
@@ -241,7 +237,7 @@ class WebApp(ApplicationABC):
else: else:
app = self._app app = self._app
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") _logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
config = uvicorn.Config( config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
@@ -249,4 +245,4 @@ class WebApp(ApplicationABC):
server = uvicorn.Server(config) server = uvicorn.Server(config)
await server.serve() await server.serve()
self._logger.info("Shutdown API") _logger.info("Shutdown API")

View File

@@ -1,7 +1,7 @@
from cpl.core.log.wrapped_logger import WrappedLogger from cpl.core.log.logger import Logger
class APILogger(WrappedLogger): class APILogger(Logger):
def __init__(self): def __init__(self, source: str):
WrappedLogger.__init__(self, "api") Logger.__init__(self, source, "api")

View File

@@ -1,4 +0,0 @@
from .authentication import AuthenticationMiddleware
from .authorization import AuthorizationMiddleware
from .logging import LoggingMiddleware
from .request import RequestMiddleware

View File

@@ -2,22 +2,24 @@ from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(ASGIMiddleware): class AuthenticationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): @ServiceProviderABC.inject
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger
self._keycloak = keycloak self._keycloak = keycloak
self._user_dao = user_dao self._user_dao = user_dao
@@ -26,11 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
url = request.url.path url = request.url.path
if url not in Router.get_auth_required_routes(): if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}") _logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send) return await self._app(scope, receive, send)
if not request.headers.get("Authorization"): if not request.headers.get("Authorization"):
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header") _logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send) return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
@@ -39,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
token = auth_header.split("Bearer ")[1] token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token): if not await self._verify_login(token):
self._logger.debug(f"Unauthorized access to {url}, invalid token") _logger.debug(f"Unauthorized access to {url}, invalid token")
return await Unauthorized("Invalid token").asgi_response(scope, receive, send) return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create # check user exists in db, if not create
@@ -49,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
user = await self._get_or_crate_user(keycloak_id) user = await self._get_or_crate_user(keycloak_id)
if user.deleted: if user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user is deleted") _logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send) return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user request.state.user = user
@@ -71,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
token_info = self._keycloak.introspect(token) token_info = self._keycloak.introspect(token)
return token_info.get("active", False) return token_info.get("active", False)
except KeycloakAuthenticationError as e: except KeycloakAuthenticationError as e:
self._logger.debug(f"Keycloak authentication error: {e}") _logger.debug(f"Keycloak authentication error: {e}")
return False return False
except Exception as e: except Exception as e:
self._logger.error(f"Unexpected error during token verification: {e}") _logger.error(f"Unexpected error during token verification: {e}")
return False return False

View File

@@ -9,15 +9,17 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger(__name__)
class AuthorizationMiddleware(ASGIMiddleware): class AuthorizationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): @ServiceProviderABC.inject
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger
self._policies = policies self._policies = policies
self._user_dao = user_dao self._user_dao = user_dao
@@ -26,7 +28,7 @@ class AuthorizationMiddleware(ASGIMiddleware):
url = request.url.path url = request.url.path
if url not in Router.get_authorization_rules_paths(): if url not in Router.get_authorization_rules_paths():
self._logger.trace(f"No authorization required for {url}") _logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send) return await self._app(scope, receive, send)
user = get_user() user = get_user()
@@ -51,18 +53,14 @@ class AuthorizationMiddleware(ASGIMiddleware):
if rule["permissions"]: if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]): if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response( return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
scope, receive, send
)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]): if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response( return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
scope, receive, send
)
for policy_name in rule["policies"]: for policy_name in rule["policies"]:
policy = self._policies.get(policy_name) policy = self._policies.get(policy_name)
if not policy: if not policy:
self._logger.warning(f"Authorization policy '{policy_name}' not found") _logger.warning(f"Authorization policy '{policy_name}' not found")
continue continue
if not await policy.resolve(user): if not await policy.resolve(user):

View File

@@ -7,14 +7,14 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware(ASGIMiddleware): class LoggingMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger): def __init__(self, app):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http": if scope["type"] != "http":
await self._call_next(scope, receive, send) await self._call_next(scope, receive, send)
@@ -53,8 +53,9 @@ class LoggingMiddleware(ASGIMiddleware):
} }
return {key: value for key, value in headers.items() if key in relevant_keys} return {key: value for key, value in headers.items() if key in relevant_keys}
async def _log_request(self, request: Request): @classmethod
self._logger.debug( async def _log_request(cls, request: Request):
_logger.debug(
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}" f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
) )
@@ -63,7 +64,7 @@ class LoggingMiddleware(ASGIMiddleware):
user = get_user() user = get_user()
request_info = { request_info = {
"headers": self._filter_relevant_headers(dict(request.headers)), "headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params), "args": dict(request.query_params),
"form-data": ( "form-data": (
await request.form() await request.form()
@@ -77,9 +78,10 @@ class LoggingMiddleware(ASGIMiddleware):
), ),
} }
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}") _logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
async def _log_after_request(self, request: Request, status_code: int, duration: float): @staticmethod
self._logger.info( async def _log_after_request(request: Request, status_code: int, duration: float):
_logger.info(
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms" f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
) )

View File

@@ -9,20 +9,16 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
_logger = APILogger(__name__)
class RequestMiddleware(ASGIMiddleware): class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger): def __init__(self, app):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._ctx_token = None self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
@@ -30,15 +26,14 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request) await self.set_request_data(request)
try: try:
with self._provider.create_scope(): await self._app(scope, receive, send)
inject(await self._app(scope, receive, send))
finally: finally:
await self.clean_request_data() await self.clean_request_data()
async def set_request_data(self, request: TRequest): async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4() request.state.request_id = uuid4()
request.state.start_time = time.time() request.state.start_time = time.time()
self._logger.trace(f"Set new current request: {request.state.request_id}") _logger.trace(f"Set new current request: {request.state.request_id}")
self._ctx_token = _request_context.set(request) self._ctx_token = _request_context.set(request)
@@ -50,7 +45,7 @@ class RequestMiddleware(ASGIMiddleware):
if self._ctx_token is None: if self._ctx_token is None:
return return
self._logger.trace(f"Clearing current request: {request.state.request_id}") _logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token) _request_context.reset(self._ctx_token)

View File

@@ -1,3 +0,0 @@
from .api_route import ApiRoute
from .policy import Policy
from .validation_match import ValidationMatch

View File

@@ -7,7 +7,13 @@ from cpl.api.typing import HTTPMethods
class ApiRoute: class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs): def __init__(
self,
path: str,
fn: Callable,
method: HTTPMethods,
**kwargs
):
self._path = path self._path = path
self._fn = fn self._fn = fn
self._method = method self._method = method

View File

@@ -1,5 +1,5 @@
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
from typing import Optional from typing import Optional, Any, Coroutine, Awaitable
from cpl.api.typing import PolicyResolver from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user from cpl.core.ctx import get_user

View File

@@ -1,2 +0,0 @@
from .policy import PolicyRegistry
from .route import RouteRegistry

View File

@@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from cpl.api.model.policy import Policy
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC from cpl.core.abc.registry_abc import RegistryABC

View File

@@ -3,7 +3,6 @@ from enum import Enum
from cpl.api.model.validation_match import ValidationMatch from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods from cpl.api.typing import HTTPMethods
from cpl.dependency import get_provider
class Router: class Router:
@@ -42,13 +41,7 @@ class Router:
return inner return inner
@classmethod @classmethod
def authorize( def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
""" """
Decorator to mark a route as requiring authorization. Decorator to mark a route as requiring authorization.
Usage: Usage:
@@ -92,14 +85,15 @@ class Router:
return inner return inner
@classmethod @classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
from cpl.api.model.api_route import ApiRoute
if not registry: if not registry:
routes = get_provider().get_service(RouteRegistry) from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else: else:
routes = registry routes = registry
def inner(fn): def inner(fn):
routes.add(ApiRoute(path, fn, method, **kwargs)) routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path) setattr(fn, "_route_path", path)
@@ -143,9 +137,8 @@ class Router:
""" """
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = get_provider().get_service(RouteRegistry) routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn): def inner(fn):
path = getattr(fn, "_route_path", None) path = getattr(fn, "_route_path", None)
if path is None: if path is None:

View File

@@ -1,26 +0,0 @@
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_module import PermissionsModule
from cpl.core.errors import dependency_error
from cpl.database.database_module import DatabaseModule
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.mysql.mysql_module import MySQLModule
from cpl.dependency.module import Module, TModule
class ApiModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return [AuthModule, DatabaseModule, PermissionsModule]
@staticmethod
def register(collection: "ServiceCollection"):
collection.add_module(DatabaseModule)
collection.add_module(AuthModule)
collection.add_module(PermissionsModule)
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)

View File

@@ -1,2 +1 @@
from .application_builder import ApplicationBuilder from .application_builder import ApplicationBuilder
from .host import Host

View File

@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
from typing import Callable, Self from typing import Callable, Self
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.console.console import Console
from cpl.core.log import LogSettings
from cpl.core.log.log_level import LogLevel from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider_abc import ServiceProviderABC
def __not_implemented__(package: str, func: Callable): def __not_implemented__(package: str, func: Callable):
@@ -16,12 +17,12 @@ class ApplicationABC(ABC):
r"""ABC for the Application class r"""ABC for the Application class
Parameters: Parameters:
services: :class:`cpl.dependency.service_provider.ServiceProvider` services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
Contains instances of prepared objects Contains instances of prepared objects
""" """
@abstractmethod @abstractmethod
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None): def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
self._services = services self._services = services
self._required_modules = ( self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else [] [x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
@@ -84,7 +85,7 @@ class ApplicationABC(ABC):
Called by custom Application.main Called by custom Application.main
""" """
try: try:
Host.run_app(self.main) Host.run(self.main)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass

View File

@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency import ServiceProviderABC
class ApplicationExtensionABC(ABC): class ApplicationExtensionABC(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def run(services: ServiceProvider): ... def run(services: ServiceProviderABC): ...

View File

@@ -7,7 +7,6 @@ from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.errors import dependency_error from cpl.core.errors import dependency_error
from cpl.dependency.context import get_provider, use_root_provider
from cpl.dependency.service_collection import ServiceCollection from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC) TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -22,7 +21,6 @@ class ApplicationBuilder(Generic[TApp]):
self._app = app if app is not None else ApplicationABC self._app = app if app is not None else ApplicationABC
self._services = ServiceCollection() self._services = ServiceCollection()
use_root_provider(self._services.build())
self._startup: Optional[StartupABC] = None self._startup: Optional[StartupABC] = None
self._app_extensions: list[Type[ApplicationExtensionABC]] = [] self._app_extensions: list[Type[ApplicationExtensionABC]] = []
@@ -36,12 +34,7 @@ class ApplicationBuilder(Generic[TApp]):
@property @property
def service_provider(self): def service_provider(self):
provider = get_provider() return self._services.build()
if provider is None:
provider = self._services.build()
use_root_provider(provider)
return provider
def validate_app_required_modules(self, app: ApplicationABC): def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules: for module in app.required_modules:
@@ -49,7 +42,6 @@ class ApplicationBuilder(Generic[TApp]):
continue continue
dependency_error( dependency_error(
type(app).__name__,
module, module,
ImportError( ImportError(
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method." f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
@@ -83,7 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
for extension in self._app_extensions: for extension in self._app_extensions:
Host.run(extension.run, self.service_provider) Host.run(extension.run, self.service_provider)
use_root_provider(self._services.build())
app = self._app(self.service_provider) app = self._app(self.service_provider)
self.validate_app_required_modules(app) self.validate_app_required_modules(app)
return app return app

View File

@@ -1,80 +1,17 @@
import asyncio import asyncio
from typing import Callable from typing import Callable
from cpl.dependency import get_provider
from cpl.dependency.hosted.startup_task import StartupTask
class Host: class Host:
_loop: asyncio.AbstractEventLoop | None = None _loop = asyncio.get_event_loop()
_tasks: dict = {}
@classmethod @classmethod
def get_loop(cls) -> asyncio.AbstractEventLoop: def get_loop(cls):
if cls._loop is None:
cls._loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls._loop)
return cls._loop return cls._loop
@classmethod
def run_start_tasks(cls):
provider = get_provider()
tasks = provider.get_services(StartupTask)
loop = cls.get_loop()
for task in tasks:
if asyncio.iscoroutinefunction(task.run):
loop.run_until_complete(task.run())
else:
task.run()
@classmethod
def run_hosted_services(cls):
provider = get_provider()
services = provider.get_hosted_services()
loop = cls.get_loop()
for service in services:
if asyncio.iscoroutinefunction(service.start):
cls._tasks[service] = loop.create_task(service.start())
@classmethod
async def _stop_all(cls):
for service in cls._tasks.keys():
if asyncio.iscoroutinefunction(service.stop):
await service.stop()
for task in cls._tasks.values():
task.cancel()
cls._tasks.clear()
@classmethod
def run_app(cls, func: Callable, *args, **kwargs):
cls.run_start_tasks()
cls.run_hosted_services()
async def runner():
try:
if asyncio.iscoroutinefunction(func):
app_task = asyncio.create_task(func(*args, **kwargs))
else:
app_task = cls.get_loop().run_in_executor(None, func, *args, **kwargs)
await asyncio.wait(
[app_task, *cls._tasks.values()],
return_when=asyncio.FIRST_COMPLETED,
)
except (KeyboardInterrupt, asyncio.CancelledError):
pass
finally:
await cls._stop_all()
cls.get_loop().run_until_complete(runner())
@classmethod @classmethod
def run(cls, func: Callable, *args, **kwargs): def run(cls, func: Callable, *args, **kwargs):
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
return cls.get_loop().run_until_complete(func(*args, **kwargs)) return cls._loop.run_until_complete(func(*args, **kwargs))
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@@ -5,8 +5,10 @@ from cpl.application.abc import ApplicationABC as _ApplicationABC
from cpl.auth import permission as _permission from cpl.auth import permission as _permission
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .auth_logger import AuthLogger
from .keycloak_settings import KeycloakSettings from .keycloak_settings import KeycloakSettings
from .logger import AuthLogger from .permission_seeder import PermissionSeeder
def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _ApplicationABC: def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _ApplicationABC:
@@ -17,5 +19,66 @@ def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _Appli
return self return self
def _add_daos(collection: _ServiceCollection):
from .schema._administration.auth_user_dao import AuthUserDao
from .schema._administration.api_key_dao import ApiKeyDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao
from .schema._permission.role_permission_dao import RolePermissionDao
from .schema._permission.role_user_dao import RoleUserDao
collection.add_singleton(AuthUserDao)
collection.add_singleton(ApiKeyDao)
collection.add_singleton(ApiKeyPermissionDao)
collection.add_singleton(PermissionDao)
collection.add_singleton(RoleDao)
collection.add_singleton(RolePermissionDao)
collection.add_singleton(RoleUserDao)
def add_auth(collection: _ServiceCollection):
import os
try:
from cpl.database.service.migration_service import MigrationService
from cpl.database.model.server_type import ServerType, ServerTypes
collection.add_singleton(_KeycloakClient)
collection.add_singleton(_KeycloakAdmin)
_add_daos(collection)
provider = collection.build()
migration_service: MigrationService = provider.get_service(MigrationService)
if ServerType.server_type == ServerTypes.POSTGRES:
migration_service.with_directory(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
)
elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
def add_permission(collection: _ServiceCollection):
from .permission_seeder import PermissionSeeder
from .permission.permissions_registry import PermissionsRegistry
from .permission.permissions import Permissions
try:
from cpl.database.abc.data_seeder_abc import DataSeederABC
collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions)
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
_ServiceCollection.with_module(add_auth, __name__)
_ServiceCollection.with_module(add_permission, _permission.__name__)
_ApplicationABC.extend(_ApplicationABC.with_permissions, _with_permissions) _ApplicationABC.extend(_ApplicationABC.with_permissions, _with_permissions)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class AuthLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "auth")

View File

@@ -1,44 +0,0 @@
import os
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
from cpl.database.database_module import DatabaseModule
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.service.migration_service import MigrationService
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
from .schema._administration.api_key_dao import ApiKeyDao
from .schema._administration.auth_user_dao import AuthUserDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao
from .schema._permission.role_permission_dao import RolePermissionDao
from .schema._permission.role_user_dao import RoleUserDao
class AuthModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return [DatabaseModule]
@staticmethod
def register(collection: ServiceCollection):
collection.add_singleton(_KeycloakClient)
collection.add_singleton(_KeycloakAdmin)
collection.add_singleton(AuthUserDao)
collection.add_singleton(ApiKeyDao)
collection.add_singleton(ApiKeyPermissionDao)
collection.add_singleton(PermissionDao)
collection.add_singleton(RoleDao)
collection.add_singleton(RolePermissionDao)
collection.add_singleton(RoleUserDao)
provider = collection.build()
migration_service: MigrationService = provider.get_service(MigrationService)
if ServerType.server_type == ServerTypes.POSTGRES:
migration_service.with_directory(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
)
elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))

View File

@@ -1,13 +1,15 @@
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings from cpl.auth.keycloak_settings import KeycloakSettings
from cpl.auth.logger import AuthLogger
_logger = AuthLogger("keycloak")
class KeycloakAdmin(_KeycloakAdmin): class KeycloakAdmin(_KeycloakAdmin):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings): def __init__(self, settings: KeycloakSettings):
# logger.info("Initializing Keycloak admin") _logger.info("Initializing Keycloak admin")
_connection = KeycloakOpenIDConnection( _connection = KeycloakOpenIDConnection(
server_url=settings.url, server_url=settings.url,
client_id=settings.client_id, client_id=settings.client_id,

View File

@@ -2,13 +2,15 @@ from typing import Optional
from keycloak import KeycloakOpenID from keycloak import KeycloakOpenID
from cpl.auth.logger import AuthLogger from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings from cpl.auth.keycloak_settings import KeycloakSettings
_logger = AuthLogger("keycloak")
class KeycloakClient(KeycloakOpenID): class KeycloakClient(KeycloakOpenID):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings): def __init__(self, settings: KeycloakSettings):
KeycloakOpenID.__init__( KeycloakOpenID.__init__(
self, self,
server_url=settings.url, server_url=settings.url,
@@ -16,7 +18,7 @@ class KeycloakClient(KeycloakOpenID):
realm_name=settings.realm, realm_name=settings.realm,
client_secret_key=settings.client_secret, client_secret_key=settings.client_secret,
) )
logger.info("Initializing Keycloak client") _logger.info("Initializing Keycloak client")
def get_user_id(self, token: str) -> Optional[str]: def get_user_id(self, token: str) -> Optional[str]:
info = self.introspect(token) info = self.introspect(token)

View File

@@ -1,5 +1,5 @@
from cpl.core.utils.get_value import get_value from cpl.core.utils.get_value import get_value
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class KeycloakUser: class KeycloakUser:
@@ -32,5 +32,5 @@ class KeycloakUser:
def id(self) -> str: def id(self) -> str:
from cpl.auth import KeycloakAdmin from cpl.auth import KeycloakAdmin
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user_id(self._username) return keycloak_admin.get_user_id(self._username)

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class AuthLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "auth")

View File

@@ -1,20 +0,0 @@
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_seeder import PermissionSeeder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class PermissionsModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
from cpl.database.database_module import DatabaseModule
return [DatabaseModule, AuthModule]
@staticmethod
def register(collection: ServiceCollection):
collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions)

View File

@@ -14,13 +14,14 @@ from cpl.auth.schema import (
) )
from cpl.core.utils.get_value import get_value from cpl.core.utils.get_value import get_value
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionSeeder(DataSeederABC): class PermissionSeeder(DataSeederABC):
def __init__( def __init__(
self, self,
logger: DBLogger,
permission_dao: PermissionDao, permission_dao: PermissionDao,
role_dao: RoleDao, role_dao: RoleDao,
role_permission_dao: RolePermissionDao, role_permission_dao: RolePermissionDao,
@@ -28,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
api_key_permission_dao: ApiKeyPermissionDao, api_key_permission_dao: ApiKeyPermissionDao,
): ):
DataSeederABC.__init__(self) DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao self._permission_dao = permission_dao
self._role_dao = role_dao self._role_dao = role_dao
self._role_permission_dao = role_permission_dao self._role_permission_dao = role_permission_dao
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
possible_permissions = [permission for permission in PermissionsRegistry.get()] possible_permissions = [permission for permission in PermissionsRegistry.get()]
if len(permissions) == len(possible_permissions): if len(permissions) == len(possible_permissions):
self._logger.info("Permissions already existing") _logger.info("Permissions already existing")
await self._update_missing_descriptions() await self._update_missing_descriptions()
return return
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
await self._permission_dao.delete_many(to_delete, hard_delete=True) await self._permission_dao.delete_many(to_delete, hard_delete=True)
self._logger.warning("Permissions incomplete") _logger.warning("Permissions incomplete")
permission_names = [permission.name for permission in permissions] permission_names = [permission.name for permission in permissions]
await self._permission_dao.create_many( await self._permission_dao.create_many(
[ [

View File

@@ -10,8 +10,7 @@ from cpl.core.log.logger import Logger
from cpl.core.typing import Id, SerialId from cpl.core.typing import Id, SerialId
from cpl.core.utils.credential_manager import CredentialManager from cpl.core.utils.credential_manager import CredentialManager
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
from cpl.dependency import get_provider from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__) _logger = Logger(__name__)
@@ -48,7 +47,7 @@ class ApiKey(DbModelABC):
async def permissions(self): async def permissions(self):
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao) apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)] return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._administration.api_key import ApiKey from cpl.auth.schema._administration.api_key import ApiKey
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyDao(DbModelDaoABC[ApiKey]): class ApiKeyDao(DbModelDaoABC[ApiKey]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys")) DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
self.attribute(ApiKey.identifier, str) self.attribute(ApiKey.identifier, str)
self.attribute(ApiKey.key, str, "keystring") self.attribute(ApiKey.key, str, "keystring")

View File

@@ -6,11 +6,13 @@ from async_property import async_property
from keycloak import KeycloakGetError from keycloak import KeycloakGetError
from cpl.auth.keycloak import KeycloakAdmin from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
_logger = AuthLogger(__name__)
class AuthUser(DbModelABC): class AuthUser(DbModelABC):
@@ -36,13 +38,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = get_provider().get_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username") return keycloak_admin.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) _logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@property @property
@@ -51,39 +52,38 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = get_provider().get_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email") return keycloak_admin.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) _logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@async_property @async_property
async def roles(self): async def roles(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao) role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)] return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
@async_property @async_property
async def permissions(self): async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_permissions(self.id) return await auth_user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.has_permission(self.id, permission) return await auth_user_dao.has_permission(self.id, permission)
async def anonymize(self): async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
self._keycloak_id = str(uuid.UUID(int=0)) self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self) await auth_user_dao.update(self)

View File

@@ -4,14 +4,17 @@ from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class AuthUserDao(DbModelDaoABC[AuthUser]): class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
@@ -36,7 +39,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool: async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value) p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class ApiKeyPermission(DbJoinModelABC): class ApiKeyPermission(DbJoinModelABC):
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
async def api_key(self): async def api_key(self):
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao) api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
return await api_key_dao.get_by_id(self._api_key_id) return await api_key_dao.get_by_id(self._api_key_id)
@property @property
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]): class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions")) DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
self.attribute(ApiKeyPermission.api_key_id, int) self.attribute(ApiKeyPermission.api_key_id, int)
self.attribute(ApiKeyPermission.permission_id, int) self.attribute(ApiKeyPermission.permission_id, int)

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._permission.permission import Permission from cpl.auth.schema._permission.permission import Permission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionDao(DbModelDaoABC[Permission]): class PermissionDao(DbModelDaoABC[Permission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions")) DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
self.attribute(Permission.name, str) self.attribute(Permission.name, str)
self.attribute(Permission.description, Optional[str]) self.attribute(Permission.description, Optional[str])

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class Role(DbModelABC): class Role(DbModelABC):
@@ -44,22 +44,22 @@ class Role(DbModelABC):
async def permissions(self): async def permissions(self):
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao) role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)] return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
@async_property @async_property
async def users(self): async def users(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao) role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)] return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao) role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
p = await permission_dao.get_by_name(permission.value) p = await permission_dao.get_by_name(permission.value)

View File

@@ -1,11 +1,14 @@
from cpl.auth.schema._permission.role import Role from cpl.auth.schema._permission.role import Role
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleDao(DbModelDaoABC[Role]): class RoleDao(DbModelDaoABC[Role]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Role, TableManager.get("roles")) DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
self.attribute(Role.name, str) self.attribute(Role.name, str)
self.attribute(Role.description, str) self.attribute(Role.description, str)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class RolePermission(DbModelABC): class RolePermission(DbModelABC):
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao) role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._role_id)
@property @property
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_permission import RolePermission from cpl.auth.schema._permission.role_permission import RolePermission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RolePermissionDao(DbModelDaoABC[RolePermission]): class RolePermissionDao(DbModelDaoABC[RolePermission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions")) DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
self.attribute(RolePermission.role_id, int) self.attribute(RolePermission.role_id, int)
self.attribute(RolePermission.permission_id, int) self.attribute(RolePermission.permission_id, int)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class RoleUser(DbJoinModelABC): class RoleUser(DbJoinModelABC):
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
async def user(self): async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._user_id) return await auth_user_dao.get_by_id(self._user_id)
@property @property
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao) role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._role_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_user import RoleUser from cpl.auth.schema._permission.role_user import RoleUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleUserDao(DbModelDaoABC[RoleUser]): class RoleUserDao(DbModelDaoABC[RoleUser]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users")) DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
self.attribute(RoleUser.role_id, int) self.attribute(RoleUser.role_id, int)
self.attribute(RoleUser.user_id, int) self.attribute(RoleUser.user_id, int)

View File

@@ -1,18 +1,17 @@
from contextvars import ContextVar from contextvars import ContextVar
from typing import Optional from typing import Optional
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) _user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
_logger = AuthLogger(__name__)
def set_user(user: Optional[AuthUser]):
from cpl.core.log.logger_abc import LoggerABC
logger = get_provider().get_service(LoggerABC) def set_user(user_id: Optional[AuthUser]):
logger.trace("Setting user context", user.id) _logger.trace("Setting user context", user_id)
_user_context.set(user) _user_context.set(user_id)
def get_user() -> Optional[AuthUser]: def get_user() -> Optional[AuthUser]:

View File

@@ -3,19 +3,8 @@ import traceback
from cpl.core.console import Console from cpl.core.console import Console
def dependency_error(src: str, package_name: str, e: ImportError = None) -> None: def dependency_error(package_name: str, e: ImportError) -> None:
Console.error(f"'{package_name}' is required to use feature: {src}. Please install it and try again.") Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
tb = traceback.format_exc()
if not tb.startswith("NoneType: None"):
Console.write_line("->", tb)
elif e is not None:
Console.write_line("->", str(e))
exit(1)
def module_dependency_error(src: str, module: str, e: ImportError = None) -> None:
Console.error(f"'{module}' is required to use feature: {src}. Please initialize it with `add_module({module})`.")
tb = traceback.format_exc() tb = traceback.format_exc()
if not tb.startswith("NoneType: None"): if not tb.startswith("NoneType: None"):
Console.write_line("->", tb) Console.write_line("->", tb)

View File

@@ -2,4 +2,3 @@ from .logger import Logger
from .logger_abc import LoggerABC from .logger_abc import LoggerABC
from .log_level import LogLevel from .log_level import LogLevel
from .log_settings import LogSettings from .log_settings import LogSettings
from .structured_logger import StructuredLogger

View File

@@ -1,111 +0,0 @@
import asyncio
import importlib.util
import json
import traceback
from datetime import datetime
from starlette.requests import Request
from cpl.core.log.log_level import LogLevel
from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages
from cpl.dependency import get_provider
class StructuredLogger(Logger):
def __init__(self, source: Source, file_prefix: str = None):
Logger.__init__(self, source, file_prefix)
@property
def log_file(self):
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
def _log(self, level: LogLevel, *messages: Messages):
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = self._format_message(level.value, timestamp, *messages)
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
self._write_log_to_file(level, structured_message)
self._write_to_console(level, formatted_message)
except Exception as e:
print(f"Error while logging: {e} -> {traceback.format_exc()}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self._source,
"messages": messages,
}
self._enrich_message_with_request(structured_message)
self._enrich_message_with_user(structured_message)
return json.dumps(structured_message, ensure_ascii=False)
@staticmethod
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
scope = dict(request.scope)
def convert(value):
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, (list, tuple)):
return [convert(v) for v in value]
if isinstance(value, dict):
return {str(k): convert(v) for k, v in value.items()}
if not isinstance(value, (str, int, float, bool, type(None))):
return str(value)
return value
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
if not include_headers and "headers" in serializable_scope:
serializable_scope["headers"] = "<omitted>"
return serializable_scope
def _enrich_message_with_request(self, message: dict):
if importlib.util.find_spec("cpl.api") is None:
return
from cpl.api.middleware.request import get_request
from starlette.requests import Request
request = get_request()
if request is None:
return
message["request"] = {
"url": str(request.url),
"method": request.method,
"scope": self._scope_to_json(request),
}
if isinstance(request, Request) and request.scope == "http":
request: Request = request # fix typing for IDEs
message["request"]["data"] = asyncio.create_task(request.body())
@staticmethod
def _enrich_message_with_user(message: dict):
if importlib.util.find_spec("cpl-auth") is None:
return
from cpl.core.ctx import get_user
user = get_user()
if user is None:
return
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = get_provider().get_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = {
"id": str(user.id),
"username": kc_user.get("username"),
"email": kc_user.get("email"),
}

View File

@@ -1,101 +0,0 @@
import inspect
from typing import Type
from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
class WrappedLogger(LoggerABC):
def __init__(self, file_prefix: str):
LoggerABC.__init__(self)
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
self._source = None
self._file_prefix = file_prefix
self._set_logger()
@inject
def _set_logger(self, services: ServiceProvider):
from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProvider")
self._logger = t_logger(self._source, self._file_prefix)
def set_level(self, level: LogLevel):
self._logger.set_level(level)
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
return self._logger._format_message(level, timestamp, *messages)
@staticmethod
def _get_source() -> str | None:
stack = inspect.stack()
if len(stack) <= 1:
return None
from cpl.dependency import ServiceCollection
ignore_classes = [
ServiceProvider,
ServiceProvider.__subclasses__(),
ServiceCollection,
WrappedLogger,
WrappedLogger.__subclasses__(),
]
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
for i, frame_info in enumerate(stack[1:]):
module = inspect.getmodule(frame_info.frame)
if module is None:
continue
if module.__name__ in ignore_classes or module in ignore_classes:
continue
if module in ignore_modules or module.__name__ in ignore_modules:
continue
if module.__name__ != __name__:
return module.__name__
return None
def _set_source(self):
self._source = self._get_source()
self._set_logger()
def header(self, string: str):
self._set_source()
self._logger.header(string)
def trace(self, *messages: Messages):
self._set_source()
self._logger.trace(*messages)
def debug(self, *messages: Messages):
self._set_source()
self._logger.debug(*messages)
def info(self, *messages: Messages):
self._set_source()
self._logger.info(*messages)
def warning(self, *messages: Messages):
self._set_source()
self._logger.warning(*messages)
def error(self, messages: str, e: Exception = None):
self._set_source()
self._logger.error(messages, e)
def fatal(self, messages: str, e: Exception = None):
self._set_source()
self._logger.fatal(messages, e)

View File

@@ -14,4 +14,3 @@ UuidId = str | UUID
SerialId = int SerialId = int
Id = UuidId | SerialId Id = UuidId | SerialId
TNumber = int | float | complex

View File

@@ -1,57 +0,0 @@
import time
import tracemalloc
from typing import List, Callable
from cpl.core.console import Console
class Benchmark:
@staticmethod
def all(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
mems: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_time = sum(times) / len(times)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
@staticmethod
def time(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
avg_time = sum(times) / len(times)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
@staticmethod
def memory(label: str, func: Callable, iterations: int = 5):
mems: List[float] = []
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")

View File

@@ -1,100 +0,0 @@
import threading
import time
from typing import Generic
from cpl.core.typing import T
class Cache(Generic[T]):
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
self._store = {}
self._default_ttl = default_ttl
self._lock = threading.Lock()
self._cleanup_interval = cleanup_interval
self._stop_event = threading.Event()
self._type = t
# Start background cleanup thread
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
self._thread.start()
def set(self, key: str, value: T, ttl: int = None) -> None:
"""Store a value in the cache with optional TTL override."""
expire_at = None
ttl = ttl if ttl is not None else self._default_ttl
if ttl is not None:
expire_at = time.time() + ttl
with self._lock:
self._store[key] = (value, expire_at)
def get(self, key: str) -> T | None:
"""Retrieve a value from the cache if not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return None
value, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return None
del self._store[key]
return None
return value
def get_all(self) -> list[T]:
"""Retrieve all non-expired values from the cache."""
now = time.time()
with self._lock:
valid_items = []
expired_keys = []
for k, (v, exp) in self._store.items():
if exp and exp < now:
expired_keys.append(k)
else:
valid_items.append(v)
for k in expired_keys:
del self._store[k]
return valid_items
def has(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return False
_, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return False
del self._store[key]
return False
return True
def delete(self, key: str) -> None:
"""Remove an item from the cache."""
with self._lock:
self._store.pop(key, None)
def clear(self) -> None:
"""Clear the entire cache."""
with self._lock:
self._store.clear()
def _auto_cleanup(self):
"""Background thread to clean expired items."""
while not self._stop_event.is_set():
self.cleanup()
self._stop_event.wait(self._cleanup_interval)
def cleanup(self) -> None:
"""Remove expired items immediately."""
now = time.time()
with self._lock:
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
for k in expired_keys:
del self._store[k]
def stop(self):
"""Stop the background cleanup thread."""
self._stop_event.set()
self._thread.join()

View File

@@ -1,48 +0,0 @@
from typing import Any
class Number:
@staticmethod
def is_number(value: Any) -> bool:
"""Check if the value is a number (int or float)."""
return isinstance(value, (int, float, complex))
@staticmethod
def to_number(value: Any) -> int | float | complex:
"""
Convert a given value into int, float, or complex.
Raises ValueError if conversion is not possible.
"""
if isinstance(value, (int, float, complex)):
return value
if isinstance(value, str):
value = value.strip()
for caster in (int, float, complex):
try:
return caster(value)
except ValueError:
continue
raise ValueError(f"Cannot convert string '{value}' to number.")
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
pass
try:
return float(value)
except Exception:
pass
try:
return complex(value)
except Exception:
pass
raise ValueError(f"Cannot convert type {type(value)} to number.")

View File

@@ -114,15 +114,12 @@ class String:
characters = [] characters = []
if letters: if letters:
characters.extend(string.ascii_letters) characters.append(string.ascii_letters)
if digits: if digits:
characters.extend(string.digits) characters.append(string.digits)
if special_characters: if special_characters:
characters.extend(string.punctuation) characters.append(string.punctuation)
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else "" return "".join(random.choice(characters) for _ in range(length)) if characters else ""
if len(x) != length:
raise Exception("No characters selected to generate random string")
return x

View File

@@ -1,12 +1,15 @@
import os import os
from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC from cpl.application.abc import ApplicationABC as _ApplicationABC
from cpl.dependency import ServiceCollection as _ServiceCollection
from . import mysql as _mysql from . import mysql as _mysql
from . import postgres as _postgres from . import postgres as _postgres
from .table_manager import TableManager from .table_manager import TableManager
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC: def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService) migration_service = self._services.get_service(MigrationService)
@@ -18,6 +21,8 @@ def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _Applica
for path in paths: for path in paths:
migration_service.with_directory(path) migration_service.with_directory(path)
Host.run(migration_service.migrate)
return self return self
@@ -30,5 +35,43 @@ def _with_seeders(self: _ApplicationABC) -> _ApplicationABC:
return self return self
def _add(collection: _ServiceCollection, db_context: Type, default_port: int, server_type: str):
from cpl.core.console import Console
from cpl.core.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.service.migration_service import MigrationService
from cpl.database.service.seeder_service import SeederService
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
try:
ServerType.set_server_type(ServerTypes(server_type))
Configuration.set("DB_DEFAULT_PORT", default_port)
collection.add_singleton(DBContextABC, db_context)
collection.add_singleton(ExecutedMigrationDao)
collection.add_singleton(MigrationService)
collection.add_singleton(SeederService)
except ImportError as e:
Console.error("cpl-database is not installed", str(e))
def add_mysql(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 3306, ServerTypes.MYSQL.value)
def add_postgres(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 5432, ServerTypes.POSTGRES.value)
_ServiceCollection.with_module(add_mysql, _mysql.__name__)
_ServiceCollection.with_module(add_postgres, _postgres.__name__)
_ApplicationABC.extend(_ApplicationABC.with_migrations, _with_migrations) _ApplicationABC.extend(_ApplicationABC.with_migrations, _with_migrations)
_ApplicationABC.extend(_ApplicationABC.with_seeders, _with_seeders) _ApplicationABC.extend(_ApplicationABC.with_seeders, _with_seeders)

View File

@@ -9,20 +9,25 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT from cpl.database.const import DATETIME_FORMAT
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency import get_provider
class DataAccessObjectABC(ABC, Generic[T_DBM]): class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
self._db = get_provider().get_service(DBContextABC) from cpl.dependency.service_provider_abc import ServiceProviderABC
self._logger = get_provider().get_service(DBLogger)
self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._logger = DBLogger(source)
self._model_type = model_type self._model_type = model_type
self._table_name = table_name self._table_name = table_name
@@ -352,13 +357,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}" values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}"
return f""" return f"""
INSERT INTO {self._table_name} ( INSERT INTO {self._table_name} (
{fields} {fields}
) VALUES ( ) VALUES (
{values} {values}
) )
{"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"} RETURNING {self.__primary_key};
""" """
async def create(self, obj: T_DBM, skip_editor=False) -> int: async def create(self, obj: T_DBM, skip_editor=False) -> int:
self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}") self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}")

View File

@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]): class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, model_type, table_name) DataAccessObjectABC.__init__(self, source, model_type, table_name)
self.attribute(DbModelABC.id, int, ignore=True) self.attribute(DbModelABC.id, int, ignore=True)
self.attribute(DbModelABC.deleted, bool) self.attribute(DbModelABC.deleted, bool)

View File

@@ -1,22 +0,0 @@
from cpl.core.errors import module_dependency_error
from cpl.database.model.server_type import ServerType
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
from cpl.database.service.migration_service import MigrationService
from cpl.database.service.seeder_service import SeederService
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class DatabaseModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
if not ServerType.has_server_type:
module_dependency_error(__name__, "MySQLModule or PostgresModule")
return []
@staticmethod
def register(collection: ServiceCollection):
collection.add_singleton(ExecutedMigrationDao)
collection.add_singleton(MigrationService)
collection.add_singleton(SeederService)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class DBLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "db")

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class DBLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "db")

View File

@@ -15,11 +15,6 @@ class ServerType:
assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}" assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}"
cls._server_type = server_type cls._server_type = server_type
@classmethod
@property
def has_server_type(cls) -> bool:
return cls._server_type is not None
@classmethod @classmethod
@property @property
def server_type(cls) -> ServerTypes: def server_type(cls) -> ServerTypes:

View File

@@ -4,17 +4,18 @@ from typing import Any, List, Dict, Tuple, Union
from mysql.connector import Error as MySQLError, PoolError from mysql.connector import Error as MySQLError, PoolError
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model.database_settings import DatabaseSettings from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_pool import MySQLPool from cpl.database.mysql.mysql_pool import MySQLPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC): class DBContext(DBContextABC):
def __init__(self, logger: DBLogger): def __init__(self):
DBContextABC.__init__(self) DBContextABC.__init__(self)
self._logger = logger
self._pool: MySQLPool = None self._pool: MySQLPool = None
self._fails = 0 self._fails = 0
@@ -22,62 +23,62 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings): def connect(self, database_settings: DatabaseSettings):
try: try:
self._logger.debug("Connecting to database") _logger.debug("Connecting to database")
self._pool = MySQLPool( self._pool = MySQLPool(
database_settings, database_settings,
) )
self._logger.info("Connected to database") _logger.info("Connected to database")
except Exception as e: except Exception as e:
self._logger.fatal("Connecting to database failed", e) _logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> List[List]: async def execute(self, statement: str, args=None, multi=True) -> List[List]:
self._logger.trace(f"execute {statement} with args: {args}") _logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi) return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> List[Dict]: async def select_map(self, statement: str, args=None) -> List[Dict]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select_map(statement, args) return await self._pool.select_map(statement, args)
except (MySQLError, PoolError) as e: except (MySQLError, PoolError) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select_map(statement, args) return await self.select_map(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]: async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select(statement, args) return await self._pool.select(statement, args)
except (MySQLError, PoolError) as e: except (MySQLError, PoolError) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select(statement, args) return await self.select(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e

View File

@@ -1,19 +0,0 @@
from cpl.core.configuration.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.mysql.db_context import DBContext
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class MySQLModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return []
@staticmethod
def register(collection: ServiceCollection):
ServerType.set_server_type(ServerTypes(ServerTypes.MYSQL.value))
Configuration.set("DB_DEFAULT_PORT", 3306)
collection.add_singleton(DBContextABC, DBContext)

View File

@@ -4,9 +4,10 @@ import sqlparse
from mysql.connector.aio import MySQLConnectionPool from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency.context import get_provider
_logger = DBLogger(__name__)
class MySQLPool: class MySQLPool:
@@ -18,11 +19,7 @@ class MySQLPool:
"user": database_settings.user, "user": database_settings.user,
"password": database_settings.password, "password": database_settings.password,
"database": database_settings.database, "database": database_settings.database,
"charset": database_settings.charset, "ssl_disabled": True,
"use_unicode": database_settings.use_unicode,
"buffered": database_settings.buffered,
"auth_plugin": database_settings.auth_plugin,
"ssl_disabled": False,
} }
self._pool: Optional[MySQLConnectionPool] = None self._pool: Optional[MySQLConnectionPool] = None
@@ -39,8 +36,7 @@ class MySQLPool:
await cursor.execute("SELECT 1") await cursor.execute("SELECT 1")
await cursor.fetchall() await cursor.fetchall()
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) _logger.fatal(f"Error connecting to the database: {e}")
logger.fatal(f"Error connecting to the database: {e}")
finally: finally:
await con.close() await con.close()

View File

@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.database_settings import DatabaseSettings
from cpl.database.model import DatabaseSettings from cpl.database.db_logger import DBLogger
from cpl.database.postgres.postgres_pool import PostgresPool from cpl.database.postgres.postgres_pool import PostgresPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC): class DBContext(DBContextABC):
def __init__(self, logger: DBLogger): def __init__(self):
DBContextABC.__init__(self) DBContextABC.__init__(self)
self._logger = logger
self._pool: PostgresPool = None self._pool: PostgresPool = None
self._fails = 0 self._fails = 0
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings): def connect(self, database_settings: DatabaseSettings):
try: try:
self._logger.debug("Connecting to database") _logger.debug("Connecting to database")
self._pool = PostgresPool( self._pool = PostgresPool(
database_settings, database_settings,
Environment.get("DB_POOL_SIZE", int, 1), Environment.get("DB_POOL_SIZE", int, 1),
) )
self._logger.info("Connected to database") _logger.info("Connected to database")
except Exception as e: except Exception as e:
self._logger.fatal("Connecting to database failed", e) _logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> list[list]: async def execute(self, statement: str, args=None, multi=True) -> list[list]:
self._logger.trace(f"execute {statement} with args: {args}") _logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi) return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> list[dict]: async def select_map(self, statement: str, args=None) -> list[dict]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select_map(statement, args) return await self._pool.select_map(statement, args)
except (OperationalError, PoolTimeout) as e: except (OperationalError, PoolTimeout) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select_map(statement, args) return await self.select_map(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]: async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select(statement, args) return await self._pool.select(statement, args)
except (OperationalError, PoolTimeout) as e: except (OperationalError, PoolTimeout) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select(statement, args) return await self.select(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e

View File

@@ -1,20 +0,0 @@
from cpl.core.configuration.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.database_module import DatabaseModule
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.postgres.db_context import DBContext
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class PostgresModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return [DatabaseModule]
@staticmethod
def register(collection: ServiceCollection):
ServerType.set_server_type(ServerTypes(ServerTypes.POSTGRES.value))
Configuration.set("DB_DEFAULT_PORT", 5432)
collection.add_singleton(DBContextABC, DBContext)

View File

@@ -5,9 +5,10 @@ from psycopg import sql
from psycopg_pool import AsyncConnectionPool, PoolTimeout from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProvider
_logger = DBLogger(__name__)
class PostgresPool: class PostgresPool:
@@ -37,8 +38,7 @@ class PostgresPool:
await pool.check_connection(con) await pool.check_connection(con)
except PoolTimeout as e: except PoolTimeout as e:
await pool.close() await pool.close()
logger = get_provider().get_service(DBLogger) _logger.fatal(f"Failed to connect to the database", e)
logger.fatal(f"Failed to connect to the database", e)
self._pool = pool self._pool = pool
return self._pool return self._pool

View File

@@ -1,11 +1,14 @@
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.db_logger import DBLogger
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
_logger = DBLogger(__name__)
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]): class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
def __init__(self): def __init__(self):
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations")) DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId") self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")

View File

@@ -2,21 +2,20 @@ import glob
import os import os
from cpl.database.abc import DBContextABC from cpl.database.abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import Migration from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
from cpl.dependency.hosted.startup_task import StartupTask
_logger = DBLogger(__name__)
class MigrationService(StartupTask): class MigrationService:
def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao): def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
StartupTask.__init__(self)
self._logger = logger
self._db = db self._db = db
self._executed_migration_dao = executed_migration_dao self._executedMigrationDao = executedMigrationDao
self._script_directories: list[str] = [] self._script_directories: list[str] = []
@@ -25,15 +24,12 @@ class MigrationService(StartupTask):
elif ServerType.server_type == ServerTypes.MYSQL: elif ServerType.server_type == ServerTypes.MYSQL:
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql")) self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
async def run(self):
await self._execute(self._load_scripts())
def with_directory(self, directory: str) -> "MigrationService": def with_directory(self, directory: str) -> "MigrationService":
self._script_directories.append(directory) self._script_directories.append(directory)
return self return self
async def _get_migration_history(self) -> list[ExecutedMigration]: async def _get_migration_history(self) -> list[ExecutedMigration]:
results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}") results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}")
applied_migrations = [] applied_migrations = []
for result in results: for result in results:
applied_migrations.append(ExecutedMigration(result[0])) applied_migrations.append(ExecutedMigration(result[0]))
@@ -96,17 +92,20 @@ class MigrationService(StartupTask):
try: try:
# check if table exists # check if table exists
if len(result) > 0: if len(result) > 0:
migration_from_db = await self._executed_migration_dao.find_by_id(migration.name) migration_from_db = await self._executedMigrationDao.find_by_id(migration.name)
if migration_from_db is not None: if migration_from_db is not None:
continue continue
self._logger.debug(f"Running upgrade migration: {migration.name}") _logger.debug(f"Running upgrade migration: {migration.name}")
await self._db.execute(migration.script, multi=True) await self._db.execute(migration.script, multi=True)
await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True) await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e: except Exception as e:
self._logger.fatal( _logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}", f"Migration failed: {migration.name}\n{active_statement}",
e, e,
) )
async def migrate(self):
await self._execute(self._load_scripts())

View File

@@ -1,16 +1,18 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class SeederService: class SeederService:
def __init__(self, provider: ServiceProvider): def __init__(self, provider: ServiceProviderABC):
self._provider = provider self._provider = provider
self._logger = provider.get_service(DBLogger)
async def seed(self): async def seed(self):
seeders = self._provider.get_services(DataSeederABC) seeders = self._provider.get_services(DataSeederABC)
self._logger.debug(f"Found {len(seeders)} seeders") _logger.debug(f"Found {len(seeders)} seeders")
for seeder in seeders: for seeder in seeders:
await seeder.seed() await seeder.seed()

View File

@@ -1,7 +1,7 @@
from .context import get_provider, use_provider from .scope import Scope
from .inject import inject from .scope_abc import ScopeABC
from .service_collection import ServiceCollection from .service_collection import ServiceCollection
from .service_descriptor import ServiceDescriptor from .service_descriptor import ServiceDescriptor
from .service_lifetime import ServiceLifetimeEnum from .service_lifetime_enum import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_provider import ServiceProvider from .service_provider import ServiceProvider
from .service_provider_abc import ServiceProviderABC

View File

@@ -1,22 +0,0 @@
import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider: "ServiceProvider"):
_current_provider.set(provider)
@contextmanager
def use_provider(provider: "ServiceProvider"):
token = _current_provider.set(provider)
try:
yield
finally:
_current_provider.reset(token)
def get_provider() -> "ServiceProvider":
return _current_provider.get()

View File

@@ -1,2 +0,0 @@
from .hosted_service import HostedService
from .startup_task import StartupTask

View File

@@ -1,9 +0,0 @@
from abc import ABC, abstractmethod
class HostedService(ABC):
@abstractmethod
async def start(self): ...
@abstractmethod
async def stop(self): ...

View File

@@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class StartupTask(ABC):
@abstractmethod
async def run(self): ...

View File

@@ -1,42 +0,0 @@
import functools
from asyncio import iscoroutinefunction
from inspect import signature
from cpl.dependency.context import get_provider
def inject(f=None):
if f is None:
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -1,14 +0,0 @@
from abc import abstractmethod, ABC
from typing import Type
TModule = Type["Module"]
class Module(ABC):
@staticmethod
@abstractmethod
def dependencies() -> list[TModule]: ...
@staticmethod
@abstractmethod
def register(collection: "ServiceCollection"): ...

View File

@@ -0,0 +1,22 @@
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class Scope(ScopeABC):
def __init__(self, service_provider: ServiceProviderABC):
self._service_provider = service_provider
self._service_provider.set_scope(self)
ScopeABC.__init__(self)
def __enter__(self):
return self
def __exit__(self, *args):
self.dispose()
@property
def service_provider(self) -> ServiceProviderABC:
return self._service_provider
def dispose(self):
self._service_provider = None

View File

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
class ScopeABC(ABC):
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
def __init__(self): ...
@property
@abstractmethod
def service_provider(self):
r"""Returns to service provider of scope
Returns:
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
"""
@abstractmethod
def dispose(self):
r"""Sets service_provider to None"""

View File

@@ -0,0 +1,18 @@
from cpl.dependency.scope import Scope
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ScopeBuilder:
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
def __init__(self, service_provider: ServiceProviderABC) -> None:
self._service_provider = service_provider
def build(self) -> ScopeABC:
r"""Returns scope
Returns:
Object of type :class:`cpl.dependency.scope.Scope`
"""
return Scope(self._service_provider)

View File

@@ -1,14 +1,12 @@
from inspect import isclass
from typing import Union, Type, Callable, Self from typing import Union, Type, Callable, Self
from cpl.core.log.logger import Logger
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache
from cpl.dependency.hosted.startup_task import StartupTask
from cpl.dependency.module import Module
from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime import ServiceLifetimeEnum from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceCollection: class ServiceCollection:
@@ -18,7 +16,7 @@ class ServiceCollection:
@classmethod @classmethod
def with_module(cls, func: Callable, name: str = None) -> type[Self]: def with_module(cls, func: Callable, name: str = None) -> type[Self]:
# cls._modules[func.__name__ if name is None else name] = func cls._modules[func.__name__ if name is None else name] = func
return cls return cls
def __init__(self): def __init__(self):
@@ -64,56 +62,23 @@ class ServiceCollection:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service) self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self return self
def add_startup_task(self, task: Type[StartupTask]) -> Self: def build(self) -> ServiceProviderABC:
self.add_singleton(StartupTask, task)
return self
def add_hosted_service(self, service_type: T, service: Service = None) -> Self:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.hosted, service)
return self
def build(self) -> ServiceProvider:
sp = ServiceProvider(self._service_descriptors) sp = ServiceProvider(self._service_descriptors)
ServiceProviderABC.set_global_provider(sp)
return sp return sp
def add_module(self, module: Type[Module]) -> Self: def add_module(self, module: str | object) -> Self:
assert isclass(module), "Module must be a Module" if not isinstance(module, str):
assert issubclass(module, Module), f"Module must be subclass of {Module.__name__}" module = module.__name__
name = module.__name__ if module not in self._modules:
if module in self._modules:
raise ValueError(f"Module {module} not found") raise ValueError(f"Module {module} not found")
for dependency in module.dependencies(): self._modules[module](self)
if dependency.__name__ not in self._loaded_modules: if module not in self._loaded_modules:
self.add_module(dependency) self._loaded_modules.add(module)
module().register(self)
if name not in self._loaded_modules:
self._loaded_modules.add(name)
return self return self
def add_logging(self) -> Self: def add_logging(self) -> Self:
from cpl.core.log.logger import Logger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, Logger) self.add_transient(LoggerABC, Logger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_structured_logging(self) -> Self:
from cpl.core.log.structured_logger import StructuredLogger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, StructuredLogger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_cache(self, t: Type[T]):
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
return self return self

View File

@@ -1,6 +1,6 @@
from typing import Union, Optional from typing import Union, Optional
from cpl.dependency.service_lifetime import ServiceLifetimeEnum from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
class ServiceDescriptor: class ServiceDescriptor:

View File

@@ -1,8 +0,0 @@
from enum import Enum, auto
class ServiceLifetimeEnum(Enum):
singleton = auto()
scoped = auto()
transient = auto()
hosted = auto()

View File

@@ -0,0 +1,7 @@
from enum import Enum
class ServiceLifetimeEnum(Enum):
singleton = 0
scoped = 1
transient = 2

View File

@@ -1,44 +1,44 @@
import copy import copy
import typing import typing
from contextlib import contextmanager
from inspect import signature, Parameter, Signature from inspect import signature, Parameter, Signature
from typing import Optional, Type from typing import Optional
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.typing import T, Source from cpl.core.typing import T, R, Source
from cpl.dependency import use_provider from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.scope_builder import ScopeBuilder
from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime import ServiceLifetimeEnum from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceProvider: class ServiceProvider(ServiceProviderABC):
def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False): r"""Provider for the services
Parameter
---------
service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`]
Descriptor of the service
config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC`
CPL Configuration
db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`]
Database representation
"""
def __init__(
self,
service_descriptors: list[ServiceDescriptor],
):
ServiceProviderABC.__init__(self)
self._service_descriptors: list[ServiceDescriptor] = service_descriptors self._service_descriptors: list[ServiceDescriptor] = service_descriptors
self._is_scope = is_scope self._scope: Optional[ScopeABC] = None
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type
type_args = list(typing.get_args(service_type))
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if typing.get_origin(service_type) is None and (descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type)): if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
return descriptor
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type
descriptor_type_args = list(typing.get_args(descriptor.base_type))
if descriptor_base_type == origin_type and len(descriptor_type_args) == 0 and len(type_args) == 0:
return descriptor
if descriptor_base_type != origin_type or len(descriptor_type_args) != len(type_args):
continue
if descriptor_base_type == origin_type and type_args != descriptor_type_args:
continue
if descriptor.service_type == origin_type or issubclass(descriptor.base_type, origin_type):
return descriptor return descriptor
return None return None
@@ -51,13 +51,15 @@ class ServiceProvider:
if descriptor.implementation is not None: if descriptor.implementation is not None:
return descriptor.implementation return descriptor.implementation
implementation = self._build_service(descriptor, origin_service_type=origin_service_type) implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped): if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptor.implementation = implementation descriptor.implementation = implementation
return implementation return implementation
def _get_services(self, t: type, *args, service_type: type = None, **kwargs) -> list[Optional[object]]: # raise Exception(f'Service {parameter.annotation} not found')
def _get_services(self, t: type, service_type: type = None, **kwargs) -> list[Optional[object]]:
implementations = [] implementations = []
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if descriptor.service_type == t or issubclass(descriptor.service_type, t): if descriptor.service_type == t or issubclass(descriptor.service_type, t):
@@ -66,27 +68,27 @@ class ServiceProvider:
continue continue
implementation = self._build_service( implementation = self._build_service(
descriptor, *args, origin_service_type=service_type, **kwargs descriptor.service_type, origin_service_type=service_type, **kwargs
) )
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped): if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptor.implementation = implementation descriptor.implementation = implementation
implementations.append(implementation) implementations.append(implementation)
return implementations return implementations
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]:
params = [] params = []
for param in sig.parameters.items(): for param in sig.parameters.items():
parameter = param[1] parameter = param[1]
if parameter.name != "self" and parameter.annotation != Parameter.empty: if parameter.name != "self" and parameter.annotation != Parameter.empty:
if typing.get_origin(parameter.annotation) == list: if typing.get_origin(parameter.annotation) == list:
params.append(self._get_services(typing.get_args(parameter.annotation)[0], service_type=origin_service_type)) params.append(self._get_services(typing.get_args(parameter.annotation)[0], origin_service_type))
elif parameter.annotation == Source: elif parameter.annotation == Source:
params.append(origin_service_type.__name__) params.append(origin_service_type.__name__)
elif issubclass(parameter.annotation, ServiceProvider): elif issubclass(parameter.annotation, ServiceProviderABC):
params.append(self) params.append(self)
elif issubclass(parameter.annotation, Environment): elif issubclass(parameter.annotation, Environment):
@@ -104,69 +106,64 @@ class ServiceProvider:
return params return params
def _build_service(self, descriptor: ServiceDescriptor, *args, origin_service_type: type = None, **kwargs) -> object: def _build_service(self, service_type: type, *args, origin_service_type: type = None, **kwargs) -> object:
if descriptor.implementation is not None:
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
if origin_service_type is None: if origin_service_type is None:
origin_service_type = service_type origin_service_type = service_type
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
if descriptor.implementation is not None:
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
break
sig = signature(service_type.__init__) sig = signature(service_type.__init__)
params = self._build_by_signature(sig, origin_service_type) params = self._build_by_signature(sig, origin_service_type)
return service_type(*params, *args, **kwargs) return service_type(*params, *args, **kwargs)
@contextmanager def set_scope(self, scope: ScopeABC):
def create_scope(self): self._scope = scope
scoped_descriptors = []
for d in self._service_descriptors: def create_scope(self) -> ScopeABC:
if d.lifetime == ServiceLifetimeEnum.singleton: descriptors = []
scoped_descriptors.append(d)
for descriptor in self._service_descriptors:
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptors.append(descriptor)
else: else:
scoped_descriptors.append(copy.deepcopy(d)) descriptors.append(copy.deepcopy(descriptor))
scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True) sb = ScopeBuilder(ServiceProvider(descriptors))
with use_provider(scoped_provider): return sb.build()
yield scoped_provider
def get_hosted_services(self) -> list[Optional[T]]: def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
hosted_services = [self.get_service(d.service_type) for d in self._service_descriptors if d.lifetime == ServiceLifetimeEnum.hosted]
return hosted_services
def get_service(self, service_type: Type[T], *args, **kwargs) -> Optional[T]:
result = self._find_service(service_type) result = self._find_service(service_type)
if result is None: if result is None:
return None return None
if result.implementation is not None: if result.implementation is not None:
return result.implementation return result.implementation
implementation = self._build_service(result, *args, **kwargs) implementation = self._build_service(service_type, *args, **kwargs)
if (
if result.lifetime == ServiceLifetimeEnum.singleton: result.lifetime == ServiceLifetimeEnum.singleton
result.implementation = implementation or result.lifetime == ServiceLifetimeEnum.scoped
elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope: and self._scope is not None
):
result.implementation = implementation result.implementation = implementation
return implementation return implementation
def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
return descriptor.service_type
return None
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
implementations = [] implementations = []
if typing.get_origin(service_type) == list: if typing.get_origin(service_type) == list:
raise Exception(f"Invalid type {service_type}! Expected single type not list of type") raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
implementations.extend(self._get_services(service_type, *args, **kwargs))
return implementations
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]: implementations.extend(self._get_services(service_type))
types = []
for descriptor in self._service_descriptors: return implementations
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
types.append(descriptor.service_type)
return types

View File

@@ -0,0 +1,137 @@
import functools
from abc import abstractmethod, ABC
from inspect import Signature, signature, iscoroutinefunction
from typing import Optional, Type
from cpl.core.typing import T, R
from cpl.dependency.scope_abc import ScopeABC
class ServiceProviderABC(ABC):
r"""ABC for the class :class:`cpl.dependency.service_provider.ServiceProvider`"""
_provider: Optional["ServiceProviderABC"] = None
@abstractmethod
def __init__(self): ...
@classmethod
def set_global_provider(cls, provider: "ServiceProviderABC"):
cls._provider = provider
@classmethod
def get_global_provider(cls) -> Optional["ServiceProviderABC"]:
return cls._provider
@classmethod
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
if cls._provider is None:
return None
return cls._provider.get_service(instance_type, *args, **kwargs)
@classmethod
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
if cls._provider is None:
return []
return cls._provider.get_services(instance_type, *args, **kwargs)
@abstractmethod
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
@abstractmethod
def _build_service(self, service_type: type, *args, **kwargs) -> object:
r"""Creates instance of given type
Parameter
---------
instance_type: :class:`type`
The type of the searched instance
Returns
-------
Object of the given type
"""
@abstractmethod
def set_scope(self, scope: ScopeABC):
r"""Sets the scope of service provider
Parameter
---------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
Service scope
"""
@abstractmethod
def create_scope(self) -> ScopeABC:
r"""Creates a service scope
Returns
-------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
"""
@abstractmethod
def get_service(self, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
r"""Returns instance of given type
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`cpl.core.type.T`]
"""
@abstractmethod
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[Optional[:class:`cpl.core.type.T`]
"""
@classmethod
def inject(cls, f=None):
r"""Decorator to allow injection into static and class methods
Parameter
---------
f: Callable
Returns
-------
function
"""
if f is None:
return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -3,7 +3,7 @@ from .abc.email_client_abc import EMailClientABC
from .email_client import EMailClient from .email_client import EMailClient
from .email_client_settings import EMailClientSettings from .email_client_settings import EMailClientSettings
from .email_model import EMail from .email_model import EMail
from .logger import MailLogger from .mail_logger import MailLogger
def add_mail(collection: _ServiceCollection): def add_mail(collection: _ServiceCollection):

View File

@@ -5,7 +5,7 @@ from typing import Optional
from cpl.mail.abc.email_client_abc import EMailClientABC from cpl.mail.abc.email_client_abc import EMailClientABC
from cpl.mail.email_client_settings import EMailClientSettings from cpl.mail.email_client_settings import EMailClientSettings
from cpl.mail.email_model import EMail from cpl.mail.email_model import EMail
from cpl.mail.logger import MailLogger from cpl.mail.mail_logger import MailLogger
class EMailClient(EMailClientABC): class EMailClient(EMailClientABC):

Some files were not shown because too many files have changed in this diff Show More