Compare commits

..

1 Commits

Author SHA1 Message Date
69bbbc8cee Authorization via with_route
Some checks failed
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 20s
Build on push / core (push) Successful in 20s
Build on push / dependency (push) Successful in 17s
Build on push / mail (push) Successful in 15s
Build on push / application (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 25s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 14s
Test before pr merge / test-lint (pull_request) Failing after 6s
2025-09-22 22:03:42 +02:00
226 changed files with 2134 additions and 2527 deletions

View File

@@ -25,11 +25,7 @@ jobs:
git tag git tag
DATE=$(date +'%Y.%m.%d') DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l) TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
if [ "$TAG_COUNT" -eq 0 ]; then BUILD_NUMBER=$(($TAG_COUNT + 1))
BUILD_NUMBER=0
else
BUILD_NUMBER=$(($TAG_COUNT + 1))
fi
VERSION_SUFFIX=${{ inputs.version_suffix }} VERSION_SUFFIX=${{ inputs.version_suffix }}
if [ -n "$VERSION_SUFFIX" ] && [ "$VERSION_SUFFIX" = "dev" ]; then if [ -n "$VERSION_SUFFIX" ] && [ "$VERSION_SUFFIX" = "dev" ]; then

3
.gitignore vendored
View File

@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
# cpl unittest stuff # cpl unittest stuff
unittests/test_*_playground unittests/test_*_playground
# cpl logs
**/logs/*.jsonl

View File

@@ -1,85 +0,0 @@
from starlette.responses import JSONResponse
from cpl.api.api_module import ApiModule
from cpl.api.application.web_app import WebApp
from cpl.application.application_builder import ApplicationBuilder
from cpl.auth import AuthModule
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule
from scoped_service import ScopedService
from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging()
builder.services.add_structured_logging()
builder.services.add_transient(PingService)
builder.services.add_module(MySQLModule)
builder.services.add_module(ApiModule)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
app = builder.build()
app.with_logging()
app.with_authentication()
app.with_authorization()
app.with_route(
path="/route1",
fn=lambda r: JSONResponse("route1"),
method="GET",
authentication=True,
permissions=[Permissions.administrator],
)
app.with_routes_directory("routes")
provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache:
raise Exception("Cache service is not working")
s1 = provider.get_service(ScopedService)
s2 = provider.get_service(ScopedService)
if s1.name == s2.name:
raise Exception("Scoped service is not working")
with provider.create_scope() as scope:
s3 = scope.get_service(ScopedService)
s4 = scope.get_service(ScopedService)
if s3.name != s4.name:
raise Exception("Scoped service is not working")
if s1.name == s3.name:
raise Exception("Scoped service is not working")
Console.write_line(
s1.name,
s2.name,
s3.name,
s4.name,
)
app.run()
if __name__ == "__main__":
main()

View File

@@ -1,21 +0,0 @@
from urllib.request import Request
from service import PingService
from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
from cpl.core.console import Console
from cpl.dependency import ServiceProvider
from scoped_service import ScopedService
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
logger.info(f"Ping: {ping}")
Console.write_line(scoped.name)
return JSONResponse(ping.ping(r))

View File

@@ -1,14 +0,0 @@
from cpl.core.console.console import Console
from cpl.core.utils.string import String
class ScopedService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -1,45 +0,0 @@
from cpl.application.abc import ApplicationABC
from cpl.core.console.console import Console
from cpl.dependency import ServiceProvider
from test_abc import TestABC
from test_service import TestService
from di_tester_service import DITesterService
from tester import Tester
class Application(ApplicationABC):
def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services)
def _part_of_scoped(self):
ts: TestService = self._services.get_service(TestService)
ts.run()
def main(self):
with self._services.create_scope() as scope:
Console.write_line("Scope1")
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
with self._services.create_scope() as scope:
Console.write_line("Scope2")
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
Console.write_line("Global")
self._part_of_scoped()
#from static_test import StaticTest
#StaticTest.test()
self._services.get_service(Tester)
Console.write_line(self._services.get_services(TestABC))

View File

@@ -1,27 +0,0 @@
from cpl.application.abc import StartupABC
from cpl.dependency import ServiceProvider, ServiceCollection
from di_tester_service import DITesterService
from test1_service import Test1Service
from test2_service import Test2Service
from test_abc import TestABC
from test_service import TestService
from tester import Tester
class Startup(StartupABC):
def __init__(self):
StartupABC.__init__(self)
@staticmethod
def configure_configuration(): ...
@staticmethod
def configure_services(services: ServiceCollection) -> ServiceProvider:
services.add_scoped(TestService)
services.add_scoped(DITesterService)
services.add_singleton(TestABC, Test1Service)
services.add_singleton(TestABC, Test2Service)
services.add_singleton(Tester)
return services.build()

View File

@@ -1,10 +0,0 @@
from cpl.dependency import ServiceProvider, ServiceProvider
from cpl.dependency.inject import inject
from test_service import TestService
class StaticTest:
@staticmethod
@inject
def test(services: ServiceProvider, t1: TestService):
t1.run()

View File

@@ -1,7 +0,0 @@
from cpl.core.console.console import Console
from test_abc import TestABC
class Tester:
def __init__(self, t1: TestABC, t2: TestABC, t3: TestABC, t: list[TestABC]):
Console.write_line("Tester:", t, t1, t2, t3)

View File

@@ -1,20 +0,0 @@
import asyncio
import time
from cpl.core.console import Console
from cpl.dependency.hosted.hosted_service import HostedService
class Hosted(HostedService):
def __init__(self):
self._stopped = False
async def start(self):
Console.write_line("Hosted Service Started")
while not self._stopped:
Console.write_line("Hosted Service Running")
await asyncio.sleep(5)
async def stop(self):
Console.write_line("Hosted Service Stopped")
self._stopped = True

View File

@@ -1,10 +0,0 @@
from cpl.core.console import Console
class ScopedService:
def __init__(self):
self.value = "I am a scoped service"
Console.write_line(self.value, self)
def get_value(self):
return self.value

View File

@@ -1,60 +0,0 @@
from cpl.core.console import Console
from cpl.core.utils.benchmark import Benchmark
from cpl.query.enumerable import Enumerable
from cpl.query.immutable_list import ImmutableList
from cpl.query.list import List
from cpl.query.set import Set
def _default():
Console.write_line(Enumerable.empty().to_list())
Console.write_line(Enumerable.range(0, 100).length)
Console.write_line(Enumerable.range(0, 100).to_list())
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
Console.write_line(
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
)
Console.write_line(List)
s =Enumerable.range(0, 10).to_set()
Console.write_line(s)
s.add(1)
Console.write_line(s)
data = Enumerable(
[
{"name": "Alice", "age": 30},
{"name": "Dave", "age": 35},
{"name": "Charlie", "age": 25},
{"name": "Bob", "age": 25},
]
)
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
def t_benchmark(data: list):
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all(
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
)
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
def main():
N = 1_000_000
data = list(range(N))
t_benchmark(data)
Console.write_line()
_default()
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,32 @@
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .logger import APILogger
from .settings import ApiSettings
from .api_module import ApiModule def add_api(collection: _ServiceCollection):
try:
from cpl.database import mysql
collection.add_module(mysql)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-database", e)
try:
from cpl import auth
from cpl.auth import permission
collection.add_module(auth)
collection.add_module(permission)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)
_ServiceCollection.with_module(add_api, __name__)

View File

@@ -1 +0,0 @@
from .asgi_middleware_abc import ASGIMiddleware

View File

@@ -1,22 +0,0 @@
from cpl.api import ApiSettings
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_module import PermissionsModule
from cpl.database.database_module import DatabaseModule
from cpl.dependency import ServiceCollection
from cpl.dependency.module.module import Module
class ApiModule(Module):
config = [ApiSettings]
singleton = [
PolicyRegistry,
RouteRegistry,
]
@staticmethod
def register(collection: ServiceCollection):
collection.add_module(DatabaseModule)
collection.add_module(AuthModule)
collection.add_module(PermissionsModule)

View File

@@ -1 +0,0 @@
from .web_app import WebApp

View File

@@ -10,7 +10,7 @@ from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.types import ExceptionHandler from starlette.types import ExceptionHandler
from cpl.api.api_module import ApiModule from cpl import api, auth
from cpl.api.error import APIError from cpl.api.error import APIError
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.middleware.authentication import AuthenticationMiddleware from cpl.api.middleware.authentication import AuthenticationMiddleware
@@ -26,45 +26,41 @@ from cpl.api.router import Router
from cpl.api.settings import ApiSettings from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC from cpl.application.abc.application_abc import ApplicationABC
from cpl.auth.auth_module import AuthModule from cpl.core.configuration import Configuration
from cpl.auth.permission.permission_module import PermissionsModule from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.core.configuration.configuration import Configuration
from cpl.dependency.inject import inject _logger = APILogger("API")
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules
PolicyInput = Union[dict[str, PolicyResolver], Policy] PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC): class WebApp(ApplicationABC):
def __init__(self, services: ServiceProvider, modules: Modules): def __init__(self, services: ServiceProviderABC):
super().__init__(services, modules, [AuthModule, PermissionsModule, ApiModule]) super().__init__(services, [auth, api])
self._app: Starlette | None = None self._app: Starlette | None = None
self._logger = services.get_service(APILogger)
self._api_settings = Configuration.get(ApiSettings) self._api_settings = Configuration.get(ApiSettings)
self._policies = services.get_service(PolicyRegistry) self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry) self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = [] self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = { self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception, Exception: self._handle_exception,
APIError: self._handle_exception, APIError: self._handle_exception,
} }
self.with_middleware(RequestMiddleware) @staticmethod
self.with_middleware(LoggingMiddleware) async def _handle_exception(request: Request, exc: Exception):
async def _handle_exception(self, request: Request, exc: Exception):
if isinstance(exc, APIError): if isinstance(exc, APIError):
self._logger.error(exc) _logger.error(exc)
return JSONResponse({"error": str(exc)}, status_code=exc.status_code) return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
if hasattr(request.state, "request_id"): if hasattr(request.state, "request_id"):
self._logger.error(f"Request {request.state.request_id}", exc) _logger.error(f"Request {request.state.request_id}", exc)
else: else:
self._logger.error("Request unknown", exc) _logger.error("Request unknown", exc)
return JSONResponse({"error": str(exc)}, status_code=500) return JSONResponse({"error": str(exc)}, status_code=500)
@@ -72,12 +68,17 @@ class WebApp(ApplicationABC):
origins = self._api_settings.allowed_origins origins = self._api_settings.allowed_origins
if origins is None or origins == "": if origins is None or origins == "":
self._logger.warning("No allowed origins specified, allowing all origins") _logger.warning("No allowed origins specified, allowing all origins")
return ["*"] return ["*"]
self._logger.debug(f"Allowed origins: {origins}") _logger.debug(f"Allowed origins: {origins}")
return origins.split(",") return origins.split(",")
def with_database(self) -> Self:
self.with_migrations()
self.with_seeders()
return self
def with_app(self, app: Starlette) -> Self: def with_app(self, app: Starlette) -> Self:
assert app is not None, "app must not be None" assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette" assert isinstance(app, Starlette), "app must be an instance of Starlette"
@@ -166,9 +167,9 @@ class WebApp(ApplicationABC):
self._check_for_app() self._check_for_app()
if isinstance(middleware, Middleware): if isinstance(middleware, Middleware):
self._middleware.append(inject(middleware)) self._middleware.append(middleware)
elif callable(middleware): elif callable(middleware):
self._middleware.append(Middleware(inject(middleware))) self._middleware.append(Middleware(middleware))
else: else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable") raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
@@ -189,11 +190,11 @@ class WebApp(ApplicationABC):
if isinstance(policy, dict): if isinstance(policy, dict):
for name, resolver in policy.items(): for name, resolver in policy.items():
if not isinstance(name, str): if not isinstance(name, str):
self._logger.warning(f"Skipping policy at index {i}, name must be a string") _logger.warning(f"Skipping policy at index {i}, name must be a string")
continue continue
if not callable(resolver): if not callable(resolver):
self._logger.warning(f"Skipping policy {name}, resolver must be callable") _logger.warning(f"Skipping policy {name}, resolver must be callable")
continue continue
_policies.append(Policy(name, resolver)) _policies.append(Policy(name, resolver))
@@ -201,7 +202,7 @@ class WebApp(ApplicationABC):
_policies.append(policy) _policies.append(policy)
self._policies.extend(_policies) self._policies.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware) self.with_middleware(AuthorizationMiddleware)
return self return self
@@ -211,14 +212,14 @@ class WebApp(ApplicationABC):
for policy_name in rule["policies"]: for policy_name in rule["policies"]:
policy = self._policies.get(policy_name) policy = self._policies.get(policy_name)
if not policy: if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found") _logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self): async def main(self):
self._logger.debug(f"Preparing API") _logger.debug(f"Preparing API")
self._validate_policies() self._validate_policies()
if self._app is None: if self._app is None:
routes = [route.to_starlette(inject) for route in self._routes.all()] routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
app = Starlette( app = Starlette(
routes=routes, routes=routes,
@@ -236,7 +237,7 @@ class WebApp(ApplicationABC):
else: else:
app = self._app app = self._app
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") _logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
config = uvicorn.Config( config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
@@ -244,4 +245,4 @@ class WebApp(ApplicationABC):
server = uvicorn.Server(config) server = uvicorn.Server(config)
await server.serve() await server.serve()
self._logger.info("Shutdown API") _logger.info("Shutdown API")

View File

@@ -1,7 +1,7 @@
from cpl.core.log.wrapped_logger import WrappedLogger from cpl.core.log.logger import Logger
class APILogger(WrappedLogger): class APILogger(Logger):
def __init__(self): def __init__(self, source: str):
WrappedLogger.__init__(self, "api") Logger.__init__(self, source, "api")

View File

@@ -1,4 +0,0 @@
from .authentication import AuthenticationMiddleware
from .authorization import AuthorizationMiddleware
from .logging import LoggingMiddleware
from .request import RequestMiddleware

View File

@@ -2,22 +2,24 @@ from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(ASGIMiddleware): class AuthenticationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): @ServiceProviderABC.inject
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger
self._keycloak = keycloak self._keycloak = keycloak
self._user_dao = user_dao self._user_dao = user_dao
@@ -26,11 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
url = request.url.path url = request.url.path
if url not in Router.get_auth_required_routes(): if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}") _logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send) return await self._app(scope, receive, send)
if not request.headers.get("Authorization"): if not request.headers.get("Authorization"):
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header") _logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send) return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
@@ -39,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
token = auth_header.split("Bearer ")[1] token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token): if not await self._verify_login(token):
self._logger.debug(f"Unauthorized access to {url}, invalid token") _logger.debug(f"Unauthorized access to {url}, invalid token")
return await Unauthorized("Invalid token").asgi_response(scope, receive, send) return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create # check user exists in db, if not create
@@ -49,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
user = await self._get_or_crate_user(keycloak_id) user = await self._get_or_crate_user(keycloak_id)
if user.deleted: if user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user is deleted") _logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send) return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user request.state.user = user
@@ -71,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
token_info = self._keycloak.introspect(token) token_info = self._keycloak.introspect(token)
return token_info.get("active", False) return token_info.get("active", False)
except KeycloakAuthenticationError as e: except KeycloakAuthenticationError as e:
self._logger.debug(f"Keycloak authentication error: {e}") _logger.debug(f"Keycloak authentication error: {e}")
return False return False
except Exception as e: except Exception as e:
self._logger.error(f"Unexpected error during token verification: {e}") _logger.error(f"Unexpected error during token verification: {e}")
return False return False

View File

@@ -9,15 +9,17 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger(__name__)
class AuthorizationMiddleware(ASGIMiddleware): class AuthorizationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): @ServiceProviderABC.inject
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger
self._policies = policies self._policies = policies
self._user_dao = user_dao self._user_dao = user_dao
@@ -26,7 +28,7 @@ class AuthorizationMiddleware(ASGIMiddleware):
url = request.url.path url = request.url.path
if url not in Router.get_authorization_rules_paths(): if url not in Router.get_authorization_rules_paths():
self._logger.trace(f"No authorization required for {url}") _logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send) return await self._app(scope, receive, send)
user = get_user() user = get_user()
@@ -51,21 +53,17 @@ class AuthorizationMiddleware(ASGIMiddleware):
if rule["permissions"]: if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]): if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response( return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
scope, receive, send
)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]): if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response( return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
scope, receive, send
)
for policy_name in rule["policies"]: for policy_name in rule["policies"]:
policy = self._policies.get(policy_name) policy = self._policies.get(policy_name)
if not policy: if not policy:
self._logger.warning(f"Authorization policy '{policy_name}' not found") _logger.warning(f"Authorization policy '{policy_name}' not found")
continue continue
if not await policy.resolve(user): if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send) return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send) return await self._call_next(scope, receive, send)

View File

@@ -7,14 +7,14 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware(ASGIMiddleware): class LoggingMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger): def __init__(self, app):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http": if scope["type"] != "http":
await self._call_next(scope, receive, send) await self._call_next(scope, receive, send)
@@ -53,8 +53,9 @@ class LoggingMiddleware(ASGIMiddleware):
} }
return {key: value for key, value in headers.items() if key in relevant_keys} return {key: value for key, value in headers.items() if key in relevant_keys}
async def _log_request(self, request: Request): @classmethod
self._logger.debug( async def _log_request(cls, request: Request):
_logger.debug(
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}" f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
) )
@@ -63,7 +64,7 @@ class LoggingMiddleware(ASGIMiddleware):
user = get_user() user = get_user()
request_info = { request_info = {
"headers": self._filter_relevant_headers(dict(request.headers)), "headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params), "args": dict(request.query_params),
"form-data": ( "form-data": (
await request.form() await request.form()
@@ -77,9 +78,10 @@ class LoggingMiddleware(ASGIMiddleware):
), ),
} }
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}") _logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
async def _log_after_request(self, request: Request, status_code: int, duration: float): @staticmethod
self._logger.info( async def _log_after_request(request: Request, status_code: int, duration: float):
_logger.info(
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms" f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
) )

View File

@@ -9,20 +9,16 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
_logger = APILogger(__name__)
class RequestMiddleware(ASGIMiddleware): class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger): def __init__(self, app):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._ctx_token = None self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
@@ -30,15 +26,14 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request) await self.set_request_data(request)
try: try:
with self._provider.create_scope(): await self._app(scope, receive, send)
inject(await self._app(scope, receive, send))
finally: finally:
await self.clean_request_data() await self.clean_request_data()
async def set_request_data(self, request: TRequest): async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4() request.state.request_id = uuid4()
request.state.start_time = time.time() request.state.start_time = time.time()
self._logger.trace(f"Set new current request: {request.state.request_id}") _logger.trace(f"Set new current request: {request.state.request_id}")
self._ctx_token = _request_context.set(request) self._ctx_token = _request_context.set(request)
@@ -50,7 +45,7 @@ class RequestMiddleware(ASGIMiddleware):
if self._ctx_token is None: if self._ctx_token is None:
return return
self._logger.trace(f"Clearing current request: {request.state.request_id}") _logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token) _request_context.reset(self._ctx_token)

View File

@@ -1,3 +0,0 @@
from .api_route import ApiRoute
from .policy import Policy
from .validation_match import ValidationMatch

View File

@@ -7,7 +7,13 @@ from cpl.api.typing import HTTPMethods
class ApiRoute: class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs): def __init__(
self,
path: str,
fn: Callable,
method: HTTPMethods,
**kwargs
):
self._path = path self._path = path
self._fn = fn self._fn = fn
self._method = method self._method = method

View File

@@ -1,5 +1,5 @@
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
from typing import Optional from typing import Optional, Any, Coroutine, Awaitable
from cpl.api.typing import PolicyResolver from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user from cpl.core.ctx import get_user

View File

@@ -1,2 +0,0 @@
from .policy import PolicyRegistry
from .route import RouteRegistry

View File

@@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from cpl.api.model.policy import Policy
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC from cpl.core.abc.registry_abc import RegistryABC

View File

@@ -3,7 +3,6 @@ from enum import Enum
from cpl.api.model.validation_match import ValidationMatch from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods from cpl.api.typing import HTTPMethods
from cpl.dependency import get_provider
class Router: class Router:
@@ -42,13 +41,7 @@ class Router:
return inner return inner
@classmethod @classmethod
def authorize( def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
""" """
Decorator to mark a route as requiring authorization. Decorator to mark a route as requiring authorization.
Usage: Usage:
@@ -92,14 +85,15 @@ class Router:
return inner return inner
@classmethod @classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
from cpl.api.model.api_route import ApiRoute
if not registry: if not registry:
routes = get_provider().get_service(RouteRegistry) from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else: else:
routes = registry routes = registry
def inner(fn): def inner(fn):
routes.add(ApiRoute(path, fn, method, **kwargs)) routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path) setattr(fn, "_route_path", path)
@@ -143,9 +137,8 @@ class Router:
""" """
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = get_provider().get_service(RouteRegistry) routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn): def inner(fn):
path = getattr(fn, "_route_path", None) path = getattr(fn, "_route_path", None)
if path is None: if path is None:
@@ -154,7 +147,7 @@ class Router:
route = routes.get(path) route = routes.get(path)
if route is None: if route is None:
raise ValueError(f"Cannot override a route that does not exist: {path}") raise ValueError(f"Cannot override a route that does not exist: {path}")
routes.add(ApiRoute(path, fn, route.method, **route.kwargs)) routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
setattr(fn, "_route_path", path) setattr(fn, "_route_path", path)
return fn return fn

View File

@@ -16,4 +16,4 @@ PartialMiddleware = Union[
Middleware, Middleware,
Callable[[ASGIApp], ASGIApp], Callable[[ASGIApp], ASGIApp],
] ]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

@@ -1,2 +1 @@
from .application_builder import ApplicationBuilder from .application_builder import ApplicationBuilder
from .host import Host

View File

@@ -2,12 +2,11 @@ from abc import ABC, abstractmethod
from typing import Callable, Self from typing import Callable, Self
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.errors import module_dependency_error from cpl.core.console.console import Console
from cpl.core.log import LogSettings
from cpl.core.log.log_level import LogLevel from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.typing import TModule
def __not_implemented__(package: str, func: Callable): def __not_implemented__(package: str, func: Callable):
@@ -18,10 +17,21 @@ class ApplicationABC(ABC):
r"""ABC for the Application class r"""ABC for the Application class
Parameters: Parameters:
services: :class:`cpl.dependency.service_provider.ServiceProvider` services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
Contains instances of prepared objects Contains instances of prepared objects
""" """
@abstractmethod
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
self._services = services
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
)
@property
def required_modules(self) -> list[str]:
return self._required_modules
@classmethod @classmethod
def extend(cls, name: str | Callable, func: Callable[[Self], Self]): def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
r"""Extend the Application with a custom method r"""Extend the Application with a custom method
@@ -38,30 +48,6 @@ class ApplicationABC(ABC):
setattr(cls, name, func) setattr(cls, name, func)
return cls return cls
@abstractmethod
def __init__(
self, services: ServiceProvider, loaded_modules: set[TModule], required_modules: list[str | object] = None
):
self._services = services
self._modules = loaded_modules
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
)
def validate_app_required_modules(self):
modules_names = {x.__name__ for x in self._modules}
for module in self._required_modules:
if module in modules_names:
continue
module_dependency_error(
type(self).__name__,
module.__name__,
ImportError(
f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
),
)
def with_logging(self, level: LogLevel = None): def with_logging(self, level: LogLevel = None):
if level is None: if level is None:
from cpl.core.configuration.configuration import Configuration from cpl.core.configuration.configuration import Configuration
@@ -72,21 +58,14 @@ class ApplicationABC(ABC):
logger = self._services.get_service(LoggerABC) logger = self._services.get_service(LoggerABC)
logger.set_level(level) logger.set_level(level)
def with_permissions(self, *args): def with_permissions(self, *args, **kwargs):
try: __not_implemented__("cpl-auth", self.with_permissions)
from cpl.auth import AuthModule
AuthModule.with_permissions(*args) def with_migrations(self, *args, **kwargs):
except ImportError: __not_implemented__("cpl-database", self.with_migrations)
__not_implemented__("cpl-auth", self.with_permissions)
def with_migrations(self, *args): def with_seeders(self, *args, **kwargs):
try: __not_implemented__("cpl-database", self.with_seeders)
from cpl.database.database_module import DatabaseModule
DatabaseModule.with_migrations(self._services, *args)
except ImportError:
__not_implemented__("cpl-database", self.with_migrations)
def with_extension(self, func: Callable[[Self, ...], None], *args, **kwargs): def with_extension(self, func: Callable[[Self, ...], None], *args, **kwargs):
r"""Extend the Application with a custom method r"""Extend the Application with a custom method
@@ -106,17 +85,9 @@ class ApplicationABC(ABC):
Called by custom Application.main Called by custom Application.main
""" """
try: try:
for module in self._modules: Host.run(self.main)
if not hasattr(module, "configure") and not callable(getattr(module, "configure")):
continue
module.configure(self._services)
Host.run_app(self.main)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally:
logger = self._services.get_service(LoggerABC)
logger.info("Application shutdown")
@abstractmethod @abstractmethod
def main(self): ... def main(self): ...

View File

@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency import ServiceProviderABC
class ApplicationExtensionABC(ABC): class ApplicationExtensionABC(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def run(services: ServiceProvider): ... def run(services: ServiceProviderABC): ...

View File

@@ -6,7 +6,7 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
from cpl.application.abc.startup_abc import StartupABC from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host from cpl.application.host import Host
from cpl.dependency.context import get_provider, use_root_provider from cpl.core.errors import dependency_error
from cpl.dependency.service_collection import ServiceCollection from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC) TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -21,7 +21,6 @@ class ApplicationBuilder(Generic[TApp]):
self._app = app if app is not None else ApplicationABC self._app = app if app is not None else ApplicationABC
self._services = ServiceCollection() self._services = ServiceCollection()
use_root_provider(self._services.build())
self._startup: Optional[StartupABC] = None self._startup: Optional[StartupABC] = None
self._app_extensions: list[Type[ApplicationExtensionABC]] = [] self._app_extensions: list[Type[ApplicationExtensionABC]] = []
@@ -35,12 +34,19 @@ class ApplicationBuilder(Generic[TApp]):
@property @property
def service_provider(self): def service_provider(self):
provider = get_provider() return self._services.build()
if provider is None:
provider = self._services.build()
use_root_provider(provider)
return provider def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules:
if module in self._services.loaded_modules:
continue
dependency_error(
module,
ImportError(
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
),
)
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder": def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
self._startup = startup self._startup = startup
@@ -69,7 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
for extension in self._app_extensions: for extension in self._app_extensions:
Host.run(extension.run, self.service_provider) Host.run(extension.run, self.service_provider)
use_root_provider(self._services.build()) app = self._app(self.service_provider)
app = self._app(self.service_provider, self._services.loaded_modules) self.validate_app_required_modules(app)
app.validate_app_required_modules()
return app return app

View File

@@ -1,80 +1,17 @@
import asyncio import asyncio
from typing import Callable from typing import Callable
from cpl.dependency import get_provider
from cpl.dependency.hosted.startup_task import StartupTask
class Host: class Host:
_loop: asyncio.AbstractEventLoop | None = None _loop = asyncio.get_event_loop()
_tasks: dict = {}
@classmethod @classmethod
def get_loop(cls) -> asyncio.AbstractEventLoop: def get_loop(cls):
if cls._loop is None:
cls._loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls._loop)
return cls._loop return cls._loop
@classmethod
def run_start_tasks(cls):
provider = get_provider()
tasks = provider.get_services(StartupTask)
loop = cls.get_loop()
for task in tasks:
if asyncio.iscoroutinefunction(task.run):
loop.run_until_complete(task.run())
else:
task.run()
@classmethod
def run_hosted_services(cls):
provider = get_provider()
services = provider.get_hosted_services()
loop = cls.get_loop()
for service in services:
if asyncio.iscoroutinefunction(service.start):
cls._tasks[service] = loop.create_task(service.start())
@classmethod
async def _stop_all(cls):
for service in cls._tasks.keys():
if asyncio.iscoroutinefunction(service.stop):
await service.stop()
for task in cls._tasks.values():
task.cancel()
cls._tasks.clear()
@classmethod
def run_app(cls, func: Callable, *args, **kwargs):
cls.run_start_tasks()
cls.run_hosted_services()
async def runner():
try:
if asyncio.iscoroutinefunction(func):
app_task = asyncio.create_task(func(*args, **kwargs))
else:
app_task = cls.get_loop().run_in_executor(None, func, *args, **kwargs)
await asyncio.wait(
[app_task, *cls._tasks.values()],
return_when=asyncio.FIRST_COMPLETED,
)
except (KeyboardInterrupt, asyncio.CancelledError):
pass
finally:
await cls._stop_all()
cls.get_loop().run_until_complete(runner())
@classmethod @classmethod
def run(cls, func: Callable, *args, **kwargs): def run(cls, func: Callable, *args, **kwargs):
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
return cls.get_loop().run_until_complete(func(*args, **kwargs)) return cls._loop.run_until_complete(func(*args, **kwargs))
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@@ -1,6 +1,84 @@
from enum import Enum
from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC
from cpl.auth import permission as _permission from cpl.auth import permission as _permission
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
from .auth_module import AuthModule from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .auth_logger import AuthLogger
from .keycloak_settings import KeycloakSettings from .keycloak_settings import KeycloakSettings
from .logger import AuthLogger from .permission_seeder import PermissionSeeder
def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _ApplicationABC:
from cpl.auth.permission.permissions_registry import PermissionsRegistry
for perm in permissions:
PermissionsRegistry.with_enum(perm)
return self
def _add_daos(collection: _ServiceCollection):
from .schema._administration.auth_user_dao import AuthUserDao
from .schema._administration.api_key_dao import ApiKeyDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao
from .schema._permission.role_permission_dao import RolePermissionDao
from .schema._permission.role_user_dao import RoleUserDao
collection.add_singleton(AuthUserDao)
collection.add_singleton(ApiKeyDao)
collection.add_singleton(ApiKeyPermissionDao)
collection.add_singleton(PermissionDao)
collection.add_singleton(RoleDao)
collection.add_singleton(RolePermissionDao)
collection.add_singleton(RoleUserDao)
def add_auth(collection: _ServiceCollection):
import os
try:
from cpl.database.service.migration_service import MigrationService
from cpl.database.model.server_type import ServerType, ServerTypes
collection.add_singleton(_KeycloakClient)
collection.add_singleton(_KeycloakAdmin)
_add_daos(collection)
provider = collection.build()
migration_service: MigrationService = provider.get_service(MigrationService)
if ServerType.server_type == ServerTypes.POSTGRES:
migration_service.with_directory(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
)
elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
def add_permission(collection: _ServiceCollection):
from .permission_seeder import PermissionSeeder
from .permission.permissions_registry import PermissionsRegistry
from .permission.permissions import Permissions
try:
from cpl.database.abc.data_seeder_abc import DataSeederABC
collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions)
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
_ServiceCollection.with_module(add_auth, __name__)
_ServiceCollection.with_module(add_permission, _permission.__name__)
_ApplicationABC.extend(_ApplicationABC.with_permissions, _with_permissions)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class AuthLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "auth")

View File

@@ -1,56 +0,0 @@
import os
from enum import Enum
from typing import Type
from cpl.auth.keycloak_settings import KeycloakSettings
from cpl.database.database_module import DatabaseModule
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.mysql.mysql_module import MySQLModule
from cpl.database.postgres.postgres_module import PostgresModule
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
from .keycloak.keycloak_admin import KeycloakAdmin
from .keycloak.keycloak_client import KeycloakClient
from .schema._administration.api_key_dao import ApiKeyDao
from .schema._administration.auth_user_dao import AuthUserDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao
from .schema._permission.role_permission_dao import RolePermissionDao
from .schema._permission.role_user_dao import RoleUserDao
class AuthModule(Module):
dependencies = [DatabaseModule, (MySQLModule, PostgresModule)]
config = [KeycloakSettings]
singleton = [
KeycloakClient,
KeycloakAdmin,
AuthUserDao,
ApiKeyDao,
ApiKeyPermissionDao,
PermissionDao,
RoleDao,
RolePermissionDao,
RoleUserDao,
]
scoped = []
transient = []
@staticmethod
def configure(provider: ServiceProvider):
paths = {
ServerTypes.POSTGRES: "scripts/postgres",
ServerTypes.MYSQL: "scripts/mysql",
}
DatabaseModule.with_migrations(
provider, str(os.path.join(os.path.dirname(os.path.realpath(__file__)), paths[ServerType.server_type]))
)
@staticmethod
def with_permissions(*permissions: Type[Enum]):
from cpl.auth.permission.permissions_registry import PermissionsRegistry
for perm in permissions:
PermissionsRegistry.with_enum(perm)

View File

@@ -1,13 +1,15 @@
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings from cpl.auth.keycloak_settings import KeycloakSettings
from cpl.auth.logger import AuthLogger
_logger = AuthLogger("keycloak")
class KeycloakAdmin(_KeycloakAdmin): class KeycloakAdmin(_KeycloakAdmin):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings): def __init__(self, settings: KeycloakSettings):
# logger.info("Initializing Keycloak admin") _logger.info("Initializing Keycloak admin")
_connection = KeycloakOpenIDConnection( _connection = KeycloakOpenIDConnection(
server_url=settings.url, server_url=settings.url,
client_id=settings.client_id, client_id=settings.client_id,

View File

@@ -2,13 +2,15 @@ from typing import Optional
from keycloak import KeycloakOpenID from keycloak import KeycloakOpenID
from cpl.auth.logger import AuthLogger from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings from cpl.auth.keycloak_settings import KeycloakSettings
_logger = AuthLogger("keycloak")
class KeycloakClient(KeycloakOpenID): class KeycloakClient(KeycloakOpenID):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings): def __init__(self, settings: KeycloakSettings):
KeycloakOpenID.__init__( KeycloakOpenID.__init__(
self, self,
server_url=settings.url, server_url=settings.url,
@@ -16,7 +18,7 @@ class KeycloakClient(KeycloakOpenID):
realm_name=settings.realm, realm_name=settings.realm,
client_secret_key=settings.client_secret, client_secret_key=settings.client_secret,
) )
logger.info("Initializing Keycloak client") _logger.info("Initializing Keycloak client")
def get_user_id(self, token: str) -> Optional[str]: def get_user_id(self, token: str) -> Optional[str]:
info = self.introspect(token) info = self.introspect(token)

View File

@@ -1,5 +1,5 @@
from cpl.core.utils.get_value import get_value from cpl.core.utils.get_value import get_value
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class KeycloakUser: class KeycloakUser:
@@ -32,5 +32,5 @@ class KeycloakUser:
def id(self) -> str: def id(self) -> str:
from cpl.auth import KeycloakAdmin from cpl.auth import KeycloakAdmin
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user_id(self._username) return keycloak_admin.get_user_id(self._username)

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class AuthLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "auth")

View File

@@ -1,4 +0,0 @@
from .permission_module import PermissionsModule
from .permission_seeder import PermissionSeeder
from .permissions import Permissions
from .permissions_registry import PermissionsRegistry

View File

@@ -1,17 +0,0 @@
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_seeder import PermissionSeeder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.database_module import DatabaseModule
from cpl.dependency.module.module import Module
from cpl.dependency.service_collection import ServiceCollection
class PermissionsModule(Module):
dependencies = [DatabaseModule, AuthModule]
singleton = [(DataSeederABC, PermissionSeeder)]
@staticmethod
def register(collection: ServiceCollection):
PermissionsRegistry.with_enum(Permissions)

View File

@@ -14,13 +14,14 @@ from cpl.auth.schema import (
) )
from cpl.core.utils.get_value import get_value from cpl.core.utils.get_value import get_value
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionSeeder(DataSeederABC): class PermissionSeeder(DataSeederABC):
def __init__( def __init__(
self, self,
logger: DBLogger,
permission_dao: PermissionDao, permission_dao: PermissionDao,
role_dao: RoleDao, role_dao: RoleDao,
role_permission_dao: RolePermissionDao, role_permission_dao: RolePermissionDao,
@@ -28,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
api_key_permission_dao: ApiKeyPermissionDao, api_key_permission_dao: ApiKeyPermissionDao,
): ):
DataSeederABC.__init__(self) DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao self._permission_dao = permission_dao
self._role_dao = role_dao self._role_dao = role_dao
self._role_permission_dao = role_permission_dao self._role_permission_dao = role_permission_dao
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
possible_permissions = [permission for permission in PermissionsRegistry.get()] possible_permissions = [permission for permission in PermissionsRegistry.get()]
if len(permissions) == len(possible_permissions): if len(permissions) == len(possible_permissions):
self._logger.info("Permissions already existing") _logger.info("Permissions already existing")
await self._update_missing_descriptions() await self._update_missing_descriptions()
return return
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
await self._permission_dao.delete_many(to_delete, hard_delete=True) await self._permission_dao.delete_many(to_delete, hard_delete=True)
self._logger.warning("Permissions incomplete") _logger.warning("Permissions incomplete")
permission_names = [permission.name for permission in permissions] permission_names = [permission.name for permission in permissions]
await self._permission_dao.create_many( await self._permission_dao.create_many(
[ [

View File

@@ -10,8 +10,7 @@ from cpl.core.log.logger import Logger
from cpl.core.typing import Id, SerialId from cpl.core.typing import Id, SerialId
from cpl.core.utils.credential_manager import CredentialManager from cpl.core.utils.credential_manager import CredentialManager
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
from cpl.dependency import get_provider from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__) _logger = Logger(__name__)
@@ -48,7 +47,7 @@ class ApiKey(DbModelABC):
async def permissions(self): async def permissions(self):
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao) apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)] return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._administration.api_key import ApiKey from cpl.auth.schema._administration.api_key import ApiKey
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyDao(DbModelDaoABC[ApiKey]): class ApiKeyDao(DbModelDaoABC[ApiKey]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys")) DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
self.attribute(ApiKey.identifier, str) self.attribute(ApiKey.identifier, str)
self.attribute(ApiKey.key, str, "keystring") self.attribute(ApiKey.key, str, "keystring")

View File

@@ -6,11 +6,13 @@ from async_property import async_property
from keycloak import KeycloakGetError from keycloak import KeycloakGetError
from cpl.auth.keycloak import KeycloakAdmin from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
_logger = AuthLogger(__name__)
class AuthUser(DbModelABC): class AuthUser(DbModelABC):
@@ -36,13 +38,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = get_provider().get_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username") return keycloak_admin.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) _logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@property @property
@@ -51,39 +52,38 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = get_provider().get_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email") return keycloak_admin.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) _logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@async_property @async_property
async def roles(self): async def roles(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao) role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)] return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
@async_property @async_property
async def permissions(self): async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_permissions(self.id) return await auth_user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.has_permission(self.id, permission) return await auth_user_dao.has_permission(self.id, permission)
async def anonymize(self): async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
self._keycloak_id = str(uuid.UUID(int=0)) self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self) await auth_user_dao.update(self)

View File

@@ -4,14 +4,17 @@ from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class AuthUserDao(DbModelDaoABC[AuthUser]): class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
@@ -36,7 +39,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool: async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value) p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class ApiKeyPermission(DbJoinModelABC): class ApiKeyPermission(DbJoinModelABC):
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
async def api_key(self): async def api_key(self):
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao) api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
return await api_key_dao.get_by_id(self._api_key_id) return await api_key_dao.get_by_id(self._api_key_id)
@property @property
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]): class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions")) DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
self.attribute(ApiKeyPermission.api_key_id, int) self.attribute(ApiKeyPermission.api_key_id, int)
self.attribute(ApiKeyPermission.permission_id, int) self.attribute(ApiKeyPermission.permission_id, int)

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._permission.permission import Permission from cpl.auth.schema._permission.permission import Permission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionDao(DbModelDaoABC[Permission]): class PermissionDao(DbModelDaoABC[Permission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions")) DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
self.attribute(Permission.name, str) self.attribute(Permission.name, str)
self.attribute(Permission.description, Optional[str]) self.attribute(Permission.description, Optional[str])

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class Role(DbModelABC): class Role(DbModelABC):
@@ -44,22 +44,22 @@ class Role(DbModelABC):
async def permissions(self): async def permissions(self):
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao) role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)] return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
@async_property @async_property
async def users(self): async def users(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao) role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)] return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao) role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
p = await permission_dao.get_by_name(permission.value) p = await permission_dao.get_by_name(permission.value)

View File

@@ -1,11 +1,14 @@
from cpl.auth.schema._permission.role import Role from cpl.auth.schema._permission.role import Role
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleDao(DbModelDaoABC[Role]): class RoleDao(DbModelDaoABC[Role]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Role, TableManager.get("roles")) DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
self.attribute(Role.name, str) self.attribute(Role.name, str)
self.attribute(Role.description, str) self.attribute(Role.description, str)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class RolePermission(DbModelABC): class RolePermission(DbModelABC):
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao) role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._role_id)
@property @property
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_permission import RolePermission from cpl.auth.schema._permission.role_permission import RolePermission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RolePermissionDao(DbModelDaoABC[RolePermission]): class RolePermissionDao(DbModelDaoABC[RolePermission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions")) DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
self.attribute(RolePermission.role_id, int) self.attribute(RolePermission.role_id, int)
self.attribute(RolePermission.permission_id, int) self.attribute(RolePermission.permission_id, int)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
class RoleUser(DbJoinModelABC): class RoleUser(DbJoinModelABC):
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
async def user(self): async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._user_id) return await auth_user_dao.get_by_id(self._user_id)
@property @property
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao) role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._role_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_user import RoleUser from cpl.auth.schema._permission.role_user import RoleUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleUserDao(DbModelDaoABC[RoleUser]): class RoleUserDao(DbModelDaoABC[RoleUser]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users")) DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
self.attribute(RoleUser.role_id, int) self.attribute(RoleUser.role_id, int)
self.attribute(RoleUser.user_id, int) self.attribute(RoleUser.user_id, int)

View File

@@ -1,18 +1,17 @@
from contextvars import ContextVar from contextvars import ContextVar
from typing import Optional from typing import Optional
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) _user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
_logger = AuthLogger(__name__)
def set_user(user: Optional[AuthUser]):
from cpl.core.log.logger_abc import LoggerABC
logger = get_provider().get_service(LoggerABC) def set_user(user_id: Optional[AuthUser]):
logger.trace("Setting user context", user.id) _logger.trace("Setting user context", user_id)
_user_context.set(user) _user_context.set(user_id)
def get_user() -> Optional[AuthUser]: def get_user() -> Optional[AuthUser]:

View File

@@ -3,25 +3,13 @@ import traceback
from cpl.core.console import Console from cpl.core.console import Console
def dependency_error(src: str, package_name: str, e: ImportError = None) -> None: def dependency_error(package_name: str, e: ImportError) -> None:
Console.error(f"'{package_name}' is required to use feature: {src}. Please install it and try again.") Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
tb = traceback.format_exc() tb = traceback.format_exc()
if not tb.startswith("NoneType: None"): if not tb.startswith("NoneType: None"):
Console.error("->", tb) Console.write_line("->", tb)
elif e is not None: elif e is not None:
Console.error(f"-> {str(e)}") Console.write_line("->", str(e))
exit(1)
def module_dependency_error(src: str, module: str, e: ImportError = None) -> None:
Console.error(f"'{module}' is required by '{src}'. Please initialize it with `add_module({module})`.")
tb = traceback.format_exc()
if not tb.startswith("NoneType: None"):
Console.error("->", tb)
elif e is not None:
Console.error(f"-> {str(e)}")
exit(1) exit(1)

View File

@@ -2,4 +2,3 @@ from .logger import Logger
from .logger_abc import LoggerABC from .logger_abc import LoggerABC
from .log_level import LogLevel from .log_level import LogLevel
from .log_settings import LogSettings from .log_settings import LogSettings
from .structured_logger import StructuredLogger

View File

@@ -93,13 +93,14 @@ class Logger(LoggerABC):
def _log(self, level: LogLevel, *messages: Messages): def _log(self, level: LogLevel, *messages: Messages):
try: try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = self._format_message(level.value, timestamp, *messages)
self._write_log_to_file(level, self._file_format_message(level.value, timestamp, *messages)) self._write_log_to_file(level, formatted_message)
self._write_to_console(level, self._console_format_message(level.value, timestamp, *messages)) self._write_to_console(level, formatted_message)
except Exception as e: except Exception as e:
print(f"Error while logging: {e} -> {traceback.format_exc()}") print(f"Error while logging: {e} -> {traceback.format_exc()}")
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
if isinstance(messages, tuple): if isinstance(messages, tuple):
messages = list(messages) messages = list(messages)
@@ -118,24 +119,6 @@ class Logger(LoggerABC):
return message return message
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str:
if isinstance(messages, tuple):
messages = list(messages)
if not isinstance(messages, list):
messages = [messages]
messages = [str(message) for message in messages if message is not None]
message = f"[{level.upper():^3}]"
message += f" [{self._file_prefix}]"
if self._source is not None:
message += f" - [{self._source}]"
message += f": {' '.join(messages)}"
return message
def header(self, string: str): def header(self, string: str):
self._log(LogLevel.info, string) self._log(LogLevel.info, string)

View File

@@ -11,10 +11,7 @@ class LoggerABC(ABC):
def set_level(self, level: LogLevel): ... def set_level(self, level: LogLevel): ...
@abstractmethod @abstractmethod
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: ... def _format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
@abstractmethod
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
@abstractmethod @abstractmethod
def header(self, string: str): def header(self, string: str):

View File

@@ -1,98 +0,0 @@
import asyncio
import importlib.util
import json
from datetime import datetime
from starlette.requests import Request
from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages
from cpl.dependency.context import get_provider
class StructuredLogger(Logger):
def __init__(self, source: Source, file_prefix: str = None):
Logger.__init__(self, source, file_prefix)
@property
def log_file(self):
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
def _file_format_message(self, level: str, timestamp: str, *messages: Messages) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self._source,
"messages": messages,
}
self._enrich_message_with_request(structured_message)
self._enrich_message_with_user(structured_message)
return json.dumps(structured_message, ensure_ascii=False)
@staticmethod
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
scope = dict(request.scope)
def convert(value):
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, (list, tuple)):
return [convert(v) for v in value]
if isinstance(value, dict):
return {str(k): convert(v) for k, v in value.items()}
if not isinstance(value, (str, int, float, bool, type(None))):
return str(value)
return value
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
if not include_headers and "headers" in serializable_scope:
serializable_scope["headers"] = "<omitted>"
return serializable_scope
def _enrich_message_with_request(self, message: dict):
if importlib.util.find_spec("cpl.api") is None:
return
from cpl.api.middleware.request import get_request
from starlette.requests import Request
request = get_request()
if request is None:
return
message["request"] = {
"url": str(request.url),
"method": request.method,
"scope": self._scope_to_json(request),
}
if isinstance(request, Request) and request.scope == "http":
request: Request = request # fix typing for IDEs
message["request"]["data"] = asyncio.create_task(request.body())
@staticmethod
def _enrich_message_with_user(message: dict):
if importlib.util.find_spec("cpl-auth") is None:
return
from cpl.core.ctx import get_user
user = get_user()
if user is None:
return
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = get_provider().get_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = {
"id": str(user.id),
"username": kc_user.get("username"),
"email": kc_user.get("email"),
}

View File

@@ -1,105 +0,0 @@
import inspect
from typing import Type
from cpl.core.log import LoggerABC, LogLevel, StructuredLogger
from cpl.core.typing import Messages
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
class WrappedLogger(LoggerABC):
def __init__(self, file_prefix: str):
LoggerABC.__init__(self)
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
self._source = None
self._file_prefix = file_prefix
self._set_logger()
@inject
def _set_logger(self, services: ServiceProvider):
from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProvider")
self._logger = t_logger(self._source, self._file_prefix)
def set_level(self, level: LogLevel):
self._logger.set_level(level)
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str:
return self._logger._file_format_message(level, timestamp, *messages)
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str:
return self._logger._console_format_message(level, timestamp, *messages)
@staticmethod
def _get_source() -> str | None:
stack = inspect.stack()
if len(stack) <= 1:
return None
from cpl.dependency import ServiceCollection
ignore_classes = [
ServiceProvider,
ServiceProvider.__subclasses__(),
ServiceCollection,
WrappedLogger,
WrappedLogger.__subclasses__(),
StructuredLogger,
]
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
for i, frame_info in enumerate(stack[1:]):
module = inspect.getmodule(frame_info.frame)
if module is None:
continue
if module.__name__ in ignore_classes or module in ignore_classes:
continue
if module in ignore_modules or module.__name__ in ignore_modules:
continue
if module.__name__ != __name__:
return module.__name__
return None
def _set_source(self):
self._source = self._get_source()
self._set_logger()
def header(self, string: str):
self._set_source()
self._logger.header(string)
def trace(self, *messages: Messages):
self._set_source()
self._logger.trace(*messages)
def debug(self, *messages: Messages):
self._set_source()
self._logger.debug(*messages)
def info(self, *messages: Messages):
self._set_source()
self._logger.info(*messages)
def warning(self, *messages: Messages):
self._set_source()
self._logger.warning(*messages)
def error(self, messages: str, e: Exception = None):
self._set_source()
self._logger.error(messages, e)
def fatal(self, messages: str, e: Exception = None):
self._set_source()
self._logger.fatal(messages, e)

View File

@@ -14,4 +14,3 @@ UuidId = str | UUID
SerialId = int SerialId = int
Id = UuidId | SerialId Id = UuidId | SerialId
TNumber = int | float | complex

View File

@@ -1,57 +0,0 @@
import time
import tracemalloc
from typing import List, Callable
from cpl.core.console import Console
class Benchmark:
@staticmethod
def all(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
mems: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_time = sum(times) / len(times)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
@staticmethod
def time(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
avg_time = sum(times) / len(times)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
@staticmethod
def memory(label: str, func: Callable, iterations: int = 5):
mems: List[float] = []
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")

View File

@@ -1,100 +0,0 @@
import threading
import time
from typing import Generic
from cpl.core.typing import T
class Cache(Generic[T]):
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
self._store = {}
self._default_ttl = default_ttl
self._lock = threading.Lock()
self._cleanup_interval = cleanup_interval
self._stop_event = threading.Event()
self._type = t
# Start background cleanup thread
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
self._thread.start()
def set(self, key: str, value: T, ttl: int = None) -> None:
"""Store a value in the cache with optional TTL override."""
expire_at = None
ttl = ttl if ttl is not None else self._default_ttl
if ttl is not None:
expire_at = time.time() + ttl
with self._lock:
self._store[key] = (value, expire_at)
def get(self, key: str) -> T | None:
"""Retrieve a value from the cache if not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return None
value, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return None
del self._store[key]
return None
return value
def get_all(self) -> list[T]:
"""Retrieve all non-expired values from the cache."""
now = time.time()
with self._lock:
valid_items = []
expired_keys = []
for k, (v, exp) in self._store.items():
if exp and exp < now:
expired_keys.append(k)
else:
valid_items.append(v)
for k in expired_keys:
del self._store[k]
return valid_items
def has(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return False
_, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return False
del self._store[key]
return False
return True
def delete(self, key: str) -> None:
"""Remove an item from the cache."""
with self._lock:
self._store.pop(key, None)
def clear(self) -> None:
"""Clear the entire cache."""
with self._lock:
self._store.clear()
def _auto_cleanup(self):
"""Background thread to clean expired items."""
while not self._stop_event.is_set():
self.cleanup()
self._stop_event.wait(self._cleanup_interval)
def cleanup(self) -> None:
"""Remove expired items immediately."""
now = time.time()
with self._lock:
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
for k in expired_keys:
del self._store[k]
def stop(self):
"""Stop the background cleanup thread."""
self._stop_event.set()
self._thread.join()

View File

@@ -1,48 +0,0 @@
from typing import Any
class Number:
@staticmethod
def is_number(value: Any) -> bool:
"""Check if the value is a number (int or float)."""
return isinstance(value, (int, float, complex))
@staticmethod
def to_number(value: Any) -> int | float | complex:
"""
Convert a given value into int, float, or complex.
Raises ValueError if conversion is not possible.
"""
if isinstance(value, (int, float, complex)):
return value
if isinstance(value, str):
value = value.strip()
for caster in (int, float, complex):
try:
return caster(value)
except ValueError:
continue
raise ValueError(f"Cannot convert string '{value}' to number.")
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
pass
try:
return float(value)
except Exception:
pass
try:
return complex(value)
except Exception:
pass
raise ValueError(f"Cannot convert type {type(value)} to number.")

View File

@@ -114,15 +114,12 @@ class String:
characters = [] characters = []
if letters: if letters:
characters.extend(string.ascii_letters) characters.append(string.ascii_letters)
if digits: if digits:
characters.extend(string.digits) characters.append(string.digits)
if special_characters: if special_characters:
characters.extend(string.punctuation) characters.append(string.punctuation)
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else "" return "".join(random.choice(characters) for _ in range(length)) if characters else ""
if len(x) != length:
raise Exception("No characters selected to generate random string")
return x

View File

@@ -1,5 +1,77 @@
import os
from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC
from cpl.dependency import ServiceCollection as _ServiceCollection
from . import mysql as _mysql from . import mysql as _mysql
from . import postgres as _postgres from . import postgres as _postgres
from .database_module import DatabaseModule
from .logger import DBLogger
from .table_manager import TableManager from .table_manager import TableManager
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService)
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
if isinstance(paths, str):
paths = [paths]
for path in paths:
migration_service.with_directory(path)
Host.run(migration_service.migrate)
return self
def _with_seeders(self: _ApplicationABC) -> _ApplicationABC:
from cpl.database.service.seeder_service import SeederService
from cpl.application.host import Host
seeder_service: SeederService = self._services.get_service(SeederService)
Host.run(seeder_service.seed)
return self
def _add(collection: _ServiceCollection, db_context: Type, default_port: int, server_type: str):
from cpl.core.console import Console
from cpl.core.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.service.migration_service import MigrationService
from cpl.database.service.seeder_service import SeederService
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
try:
ServerType.set_server_type(ServerTypes(server_type))
Configuration.set("DB_DEFAULT_PORT", default_port)
collection.add_singleton(DBContextABC, db_context)
collection.add_singleton(ExecutedMigrationDao)
collection.add_singleton(MigrationService)
collection.add_singleton(SeederService)
except ImportError as e:
Console.error("cpl-database is not installed", str(e))
def add_mysql(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 3306, ServerTypes.MYSQL.value)
def add_postgres(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 5432, ServerTypes.POSTGRES.value)
_ServiceCollection.with_module(add_mysql, _mysql.__name__)
_ServiceCollection.with_module(add_postgres, _postgres.__name__)
_ApplicationABC.extend(_ApplicationABC.with_migrations, _with_migrations)
_ApplicationABC.extend(_ApplicationABC.with_seeders, _with_seeders)

View File

@@ -1,6 +1,4 @@
from .connection_abc import ConnectionABC from .connection_abc import ConnectionABC
from .data_access_object_abc import DataAccessObjectABC
from .data_seeder_abc import DataSeederABC
from .db_context_abc import DBContextABC from .db_context_abc import DBContextABC
from .db_join_model_abc import DbJoinModelABC from .db_join_model_abc import DbJoinModelABC
from .db_model_abc import DbModelABC from .db_model_abc import DbModelABC

View File

@@ -9,20 +9,25 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT from cpl.database.const import DATETIME_FORMAT
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency.context import get_provider
class DataAccessObjectABC(ABC, Generic[T_DBM]): class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
self._db = get_provider().get_service(DBContextABC) from cpl.dependency.service_provider_abc import ServiceProviderABC
self._logger = get_provider().get_service(DBLogger)
self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._logger = DBLogger(source)
self._model_type = model_type self._model_type = model_type
self._table_name = table_name self._table_name = table_name
@@ -352,13 +357,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}" values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}"
return f""" return f"""
INSERT INTO {self._table_name} ( INSERT INTO {self._table_name} (
{fields} {fields}
) VALUES ( ) VALUES (
{values} {values}
) )
{"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"} RETURNING {self.__primary_key};
""" """
async def create(self, obj: T_DBM, skip_editor=False) -> int: async def create(self, obj: T_DBM, skip_editor=False) -> int:
self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}") self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}")

View File

@@ -2,7 +2,7 @@ from abc import abstractmethod
from datetime import datetime from datetime import datetime
from typing import Type from typing import Type
from cpl.database.table_manager import TableManager from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]): class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, model_type, table_name) DataAccessObjectABC.__init__(self, source, model_type, table_name)
self.attribute(DbModelABC.id, int, ignore=True) self.attribute(DbModelABC.id, int, ignore=True)
self.attribute(DbModelABC.deleted, bool) self.attribute(DbModelABC.deleted, bool)

View File

@@ -1,33 +0,0 @@
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_module import MySQLModule
from cpl.database.postgres.postgres_module import PostgresModule
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
from cpl.database.service.migration_service import MigrationService
from cpl.database.service.seeder_service import SeederService
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
class DatabaseModule(Module):
dependencies = [(MySQLModule, PostgresModule)]
config = [DatabaseSettings]
singleton = [
ExecutedMigrationDao,
MigrationService,
SeederService,
]
@classmethod
def configure(cls, provider: ServiceProvider): ...
@staticmethod
def with_migrations(services: ServiceProvider, *paths: str | list[str]):
from cpl.database.service.migration_service import MigrationService
migration_service = services.get_service(MigrationService)
if isinstance(paths, str):
paths = [paths]
for path in paths:
migration_service.with_directory(path)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class DBLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "db")

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class DBLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "db")

View File

@@ -21,4 +21,4 @@ class DatabaseSettings(ConfigurationModelABC):
self.option("use_unicode", bool, False) self.option("use_unicode", bool, False)
self.option("buffered", bool, False) self.option("buffered", bool, False)
self.option("auth_plugin", str, "caching_sha2_password") self.option("auth_plugin", str, "caching_sha2_password")
self.option("ssl_disabled", bool, True) self.option("ssl_disabled", bool, False)

View File

@@ -15,11 +15,6 @@ class ServerType:
assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}" assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}"
cls._server_type = server_type cls._server_type = server_type
@classmethod
@property
def has_server_type(cls) -> bool:
return cls._server_type is not None
@classmethod @classmethod
@property @property
def server_type(cls) -> ServerTypes: def server_type(cls) -> ServerTypes:

View File

@@ -1,4 +0,0 @@
from .connection import DatabaseConnection
from .db_context import DBContext
from .mysql_module import MySQLModule
from .mysql_pool import MySQLPool

View File

@@ -4,17 +4,18 @@ from typing import Any, List, Dict, Tuple, Union
from mysql.connector import Error as MySQLError, PoolError from mysql.connector import Error as MySQLError, PoolError
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model.database_settings import DatabaseSettings from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_pool import MySQLPool from cpl.database.mysql.mysql_pool import MySQLPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC): class DBContext(DBContextABC):
def __init__(self, logger: DBLogger): def __init__(self):
DBContextABC.__init__(self) DBContextABC.__init__(self)
self._logger = logger
self._pool: MySQLPool = None self._pool: MySQLPool = None
self._fails = 0 self._fails = 0
@@ -22,62 +23,62 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings): def connect(self, database_settings: DatabaseSettings):
try: try:
self._logger.debug("Connecting to database") _logger.debug("Connecting to database")
self._pool = MySQLPool( self._pool = MySQLPool(
database_settings, database_settings,
) )
self._logger.info("Connected to database") _logger.info("Connected to database")
except Exception as e: except Exception as e:
self._logger.fatal("Connecting to database failed", e) _logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> List[List]: async def execute(self, statement: str, args=None, multi=True) -> List[List]:
self._logger.trace(f"execute {statement} with args: {args}") _logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi) return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> List[Dict]: async def select_map(self, statement: str, args=None) -> List[Dict]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select_map(statement, args) return await self._pool.select_map(statement, args)
except (MySQLError, PoolError) as e: except (MySQLError, PoolError) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select_map(statement, args) return await self.select_map(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]: async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select(statement, args) return await self._pool.select(statement, args)
except (MySQLError, PoolError) as e: except (MySQLError, PoolError) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select(statement, args) return await self.select(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e

View File

@@ -1,17 +0,0 @@
from cpl.core.configuration.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.mysql.db_context import DBContext
from cpl.dependency.module.module import Module
from cpl.dependency.service_collection import ServiceCollection
class MySQLModule(Module):
config = [DatabaseSettings]
singleton = [(DBContextABC, DBContext)]
@staticmethod
def register(collection: ServiceCollection):
ServerType.set_server_type(ServerTypes(ServerTypes.MYSQL.value))
Configuration.set("DB_DEFAULT_PORT", 3306)

View File

@@ -4,9 +4,10 @@ import sqlparse
from mysql.connector.aio import MySQLConnectionPool from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency.context import get_provider
_logger = DBLogger(__name__)
class MySQLPool: class MySQLPool:
@@ -18,31 +19,26 @@ class MySQLPool:
"user": database_settings.user, "user": database_settings.user,
"password": database_settings.password, "password": database_settings.password,
"database": database_settings.database, "database": database_settings.database,
"charset": database_settings.charset, "ssl_disabled": True,
"use_unicode": database_settings.use_unicode,
"buffered": database_settings.buffered,
"auth_plugin": database_settings.auth_plugin,
"ssl_disabled": database_settings.ssl_disabled,
} }
self._pool: Optional[MySQLConnectionPool] = None self._pool: Optional[MySQLConnectionPool] = None
async def _get_pool(self): async def _get_pool(self):
if self._pool is None: if self._pool is None:
try: self._pool = MySQLConnectionPool(
self._pool = MySQLConnectionPool( pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig )
) await self._pool.initialize_pool()
await self._pool.initialize_pool()
con = await self._pool.get_connection() con = await self._pool.get_connection()
try:
async with await con.cursor() as cursor: async with await con.cursor() as cursor:
await cursor.execute("SELECT 1") await cursor.execute("SELECT 1")
await cursor.fetchall() await cursor.fetchall()
await con.close()
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) _logger.fatal(f"Error connecting to the database: {e}")
logger.fatal(f"Error connecting to the database", e) finally:
await con.close()
return self._pool return self._pool

View File

@@ -1,4 +0,0 @@
from .db_context import DBContext
from .postgres_module import PostgresModule
from .postgres_pool import PostgresPool
from .sql_select_builder import SQLSelectBuilder

View File

@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.database_settings import DatabaseSettings
from cpl.database.model import DatabaseSettings from cpl.database.db_logger import DBLogger
from cpl.database.postgres.postgres_pool import PostgresPool from cpl.database.postgres.postgres_pool import PostgresPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC): class DBContext(DBContextABC):
def __init__(self, logger: DBLogger): def __init__(self):
DBContextABC.__init__(self) DBContextABC.__init__(self)
self._logger = logger
self._pool: PostgresPool = None self._pool: PostgresPool = None
self._fails = 0 self._fails = 0
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings): def connect(self, database_settings: DatabaseSettings):
try: try:
self._logger.debug("Connecting to database") _logger.debug("Connecting to database")
self._pool = PostgresPool( self._pool = PostgresPool(
database_settings, database_settings,
Environment.get("DB_POOL_SIZE", int, 1), Environment.get("DB_POOL_SIZE", int, 1),
) )
self._logger.info("Connected to database") _logger.info("Connected to database")
except Exception as e: except Exception as e:
self._logger.fatal("Connecting to database failed", e) _logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> list[list]: async def execute(self, statement: str, args=None, multi=True) -> list[list]:
self._logger.trace(f"execute {statement} with args: {args}") _logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi) return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> list[dict]: async def select_map(self, statement: str, args=None) -> list[dict]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select_map(statement, args) return await self._pool.select_map(statement, args)
except (OperationalError, PoolTimeout) as e: except (OperationalError, PoolTimeout) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select_map(statement, args) return await self.select_map(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]: async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select(statement, args) return await self._pool.select(statement, args)
except (OperationalError, PoolTimeout) as e: except (OperationalError, PoolTimeout) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select(statement, args) return await self.select(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e

View File

@@ -1,17 +0,0 @@
from cpl.core.configuration.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.postgres.db_context import DBContext
from cpl.dependency.module.module import Module
from cpl.dependency.service_collection import ServiceCollection
class PostgresModule(Module):
config = [DatabaseSettings]
singleton = [(DBContextABC, DBContext)]
@staticmethod
def register(collection: ServiceCollection):
ServerType.set_server_type(ServerTypes(ServerTypes.POSTGRES.value))
Configuration.set("DB_DEFAULT_PORT", 5432)

View File

@@ -5,9 +5,10 @@ from psycopg import sql
from psycopg_pool import AsyncConnectionPool, PoolTimeout from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency.context import get_provider
_logger = DBLogger(__name__)
class PostgresPool: class PostgresPool:
@@ -31,16 +32,14 @@ class PostgresPool:
pool = AsyncConnectionPool( pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
) )
await pool.open()
try: try:
await pool.open()
async with pool.connection() as con: async with pool.connection() as con:
await pool.check_connection(con) await pool.check_connection(con)
self._pool = pool
except PoolTimeout as e: except PoolTimeout as e:
await pool.close() await pool.close()
logger = get_provider().get_service(DBLogger) _logger.fatal(f"Failed to connect to the database", e)
logger.fatal(f"Failed to connect to the database", e) self._pool = pool
return self._pool return self._pool

View File

@@ -1,2 +0,0 @@
from .executed_migration import ExecutedMigration
from .executed_migration_dao import ExecutedMigrationDao

View File

@@ -1,11 +1,14 @@
from cpl.database.table_manager import TableManager from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.db_logger import DBLogger
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
_logger = DBLogger(__name__)
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]): class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
def __init__(self): def __init__(self):
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations")) DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId") self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")

View File

@@ -1,2 +0,0 @@
from .seeder_service import SeederService
from .migration_service import MigrationService

View File

@@ -1,21 +1,21 @@
import glob import glob
import os import os
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model.migration import Migration from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
from cpl.dependency.hosted import StartupTask
_logger = DBLogger(__name__)
class MigrationService(StartupTask): class MigrationService:
def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao): def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
self._logger = logger
self._db = db self._db = db
self._executed_migration_dao = executed_migration_dao self._executedMigrationDao = executedMigrationDao
self._script_directories: list[str] = [] self._script_directories: list[str] = []
@@ -23,27 +23,13 @@ class MigrationService(StartupTask):
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/postgres")) self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/postgres"))
elif ServerType.server_type == ServerTypes.MYSQL: elif ServerType.server_type == ServerTypes.MYSQL:
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql")) self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
else:
raise Exception("Unsupported database type")
async def run(self):
await self._execute(self._load_scripts())
def with_directory(self, directory: str) -> "MigrationService": def with_directory(self, directory: str) -> "MigrationService":
cpl_rel_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../..") self._script_directories.append(directory)
cpl_abs_path = os.path.abspath(cpl_rel_path)
if directory.startswith(cpl_abs_path) or os.path.abspath(directory).startswith(cpl_abs_path):
if len(self._script_directories) > 0:
self._script_directories.insert(1, directory)
else:
self._script_directories.append(directory)
else:
self._script_directories.append(directory)
return self return self
async def _get_migration_history(self) -> list[ExecutedMigration]: async def _get_migration_history(self) -> list[ExecutedMigration]:
results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}") results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}")
applied_migrations = [] applied_migrations = []
for result in results: for result in results:
applied_migrations.append(ExecutedMigration(result[0])) applied_migrations.append(ExecutedMigration(result[0]))
@@ -106,17 +92,20 @@ class MigrationService(StartupTask):
try: try:
# check if table exists # check if table exists
if len(result) > 0: if len(result) > 0:
migration_from_db = await self._executed_migration_dao.find_by_id(migration.name) migration_from_db = await self._executedMigrationDao.find_by_id(migration.name)
if migration_from_db is not None: if migration_from_db is not None:
continue continue
self._logger.debug(f"Running upgrade migration: {migration.name}") _logger.debug(f"Running upgrade migration: {migration.name}")
await self._db.execute(migration.script, multi=True) await self._db.execute(migration.script, multi=True)
await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True) await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e: except Exception as e:
self._logger.fatal( _logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}", f"Migration failed: {migration.name}\n{active_statement}",
e, e,
) )
async def migrate(self):
await self._execute(self._load_scripts())

View File

@@ -1,18 +1,18 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProviderABC
from cpl.dependency.hosted import StartupTask
class SeederService(StartupTask): _logger = DBLogger(__name__)
def __init__(self, provider: ServiceProvider):
StartupTask.__init__(self) class SeederService:
def __init__(self, provider: ServiceProviderABC):
self._provider = provider self._provider = provider
self._logger = provider.get_service(DBLogger)
async def run(self): async def seed(self):
seeders = self._provider.get_services(DataSeederABC) seeders = self._provider.get_services(DataSeederABC)
self._logger.debug(f"Found {len(seeders)} seeders") _logger.debug(f"Found {len(seeders)} seeders")
for seeder in seeders: for seeder in seeders:
await seeder.seed() await seeder.seed()

View File

@@ -1,7 +1,7 @@
from .context import get_provider, use_provider from .scope import Scope
from .inject import inject from .scope_abc import ScopeABC
from .service_collection import ServiceCollection from .service_collection import ServiceCollection
from .service_descriptor import ServiceDescriptor from .service_descriptor import ServiceDescriptor
from .service_lifetime import ServiceLifetimeEnum from .service_lifetime_enum import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_provider import ServiceProvider from .service_provider import ServiceProvider
from .service_provider_abc import ServiceProviderABC

View File

@@ -1,22 +0,0 @@
import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider: "ServiceProvider"):
_current_provider.set(provider)
@contextmanager
def use_provider(provider: "ServiceProvider"):
token = _current_provider.set(provider)
try:
yield
finally:
_current_provider.reset(token)
def get_provider() -> "ServiceProvider":
return _current_provider.get()

View File

@@ -1,2 +0,0 @@
from .hosted_service import HostedService
from .startup_task import StartupTask

View File

@@ -1,9 +0,0 @@
from abc import ABC, abstractmethod
class HostedService(ABC):
@abstractmethod
async def start(self): ...
@abstractmethod
async def stop(self): ...

View File

@@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class StartupTask(ABC):
@abstractmethod
async def run(self): ...

View File

@@ -1,42 +0,0 @@
import functools
from asyncio import iscoroutinefunction
from inspect import signature
from cpl.dependency.context import get_provider
def inject(f=None):
if f is None:
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -1,10 +0,0 @@
from cpl.dependency.module.module_abc import ModuleABC
class Module(ModuleABC):
@staticmethod
def register(collection: "ServiceCollection"): ...
@staticmethod
def configure(provider: "ServiceProvider"): ...

View File

@@ -1,60 +0,0 @@
from abc import ABC, abstractmethod
from inspect import isclass
from cpl.core.configuration import ConfigurationModelABC
class ModuleABC(ABC):
__OPTIONAL_VARS = ["dependencies", "configuration", "singleton", "scoped", "transient", "hosted"]
def __init_subclass__(cls):
super().__init_subclass__()
if f"{cls.__module__}.{cls.__name__}" == "cpl.dependency.module.module.Module":
return
for var in cls.__OPTIONAL_VARS:
if not hasattr(cls, var):
continue
value = getattr(cls, var)
if not isinstance(value, list):
raise TypeError(f"'{var}' attribute of {cls.__name__} must be a list, not {type(value).__name__}")
for dep in value:
if var == "config":
if not isclass(dep) or not issubclass(dep, ConfigurationModelABC):
raise TypeError(
f"Invalid config {dep} in {cls.__name__}: must be subclass of ConfigurationModelABC"
)
elif var == "dependencies":
if not isinstance(dep, (list, tuple)) and not isclass(dep):
raise TypeError(f"Invalid dependency {dep} in {cls.__name__}")
else:
if not isinstance(dep, tuple) and not isclass(dep):
raise TypeError(f"Invalid {var} {dep} in {cls.__name__}")
@classmethod
def get_singleton(cls):
return getattr(cls, "singleton", [])
@classmethod
def get_scoped(cls):
return getattr(cls, "scoped", [])
@classmethod
def get_transient(cls):
return getattr(cls, "transient", [])
@classmethod
def get_hosted(cls):
return getattr(cls, "hosted", [])
@staticmethod
@abstractmethod
def register(collection: "ServiceCollection"): ...
@staticmethod
@abstractmethod
def configure(provider: "ServiceProvider"): ...

Some files were not shown because too many files have changed in this diff Show More