Compare commits

..

1 Commits

Author SHA1 Message Date
69bbbc8cee Authorization via with_route
Some checks failed
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 20s
Build on push / core (push) Successful in 20s
Build on push / dependency (push) Successful in 17s
Build on push / mail (push) Successful in 15s
Build on push / application (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 25s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 14s
Test before pr merge / test-lint (pull_request) Failing after 6s
2025-09-22 22:03:42 +02:00
219 changed files with 1981 additions and 2195 deletions

View File

@@ -25,11 +25,7 @@ jobs:
git tag
DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
if [ "$TAG_COUNT" -eq 0 ]; then
BUILD_NUMBER=0
else
BUILD_NUMBER=$(($TAG_COUNT + 1))
fi
BUILD_NUMBER=$(($TAG_COUNT + 1))
VERSION_SUFFIX=${{ inputs.version_suffix }}
if [ -n "$VERSION_SUFFIX" ] && [ "$VERSION_SUFFIX" = "dev" ]; then

3
.gitignore vendored
View File

@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
# cpl unittest stuff
unittests/test_*_playground
# cpl logs
**/logs/*.jsonl

View File

@@ -1,79 +0,0 @@
from starlette.responses import JSONResponse
from cpl.api.api_module import ApiModule
from cpl.api.application.web_app import WebApp
from cpl.application.application_builder import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule
from scoped_service import ScopedService
from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging()
builder.services.add_structured_logging()
builder.services.add_transient(PingService)
builder.services.add_module(MySQLModule)
builder.services.add_module(ApiModule)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
app = builder.build()
app.with_logging()
app.with_database()
app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
app.with_routes_directory("routes")
provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache:
raise Exception("Cache service is not working")
s1 = provider.get_service(ScopedService)
s2 = provider.get_service(ScopedService)
if s1.name == s2.name:
raise Exception("Scoped service is not working")
with provider.create_scope() as scope:
s3 = scope.get_service(ScopedService)
s4 = scope.get_service(ScopedService)
if s3.name != s4.name:
raise Exception("Scoped service is not working")
if s1.name == s3.name:
raise Exception("Scoped service is not working")
Console.write_line(
s1.name,
s2.name,
s3.name,
s4.name,
)
app.run()
if __name__ == "__main__":
main()

View File

@@ -1,21 +0,0 @@
from urllib.request import Request
from service import PingService
from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
from cpl.core.console import Console
from cpl.dependency import ServiceProvider
from scoped_service import ScopedService
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
logger.info(f"Ping: {ping}")
Console.write_line(scoped.name)
return JSONResponse(ping.ping(r))

View File

@@ -1,14 +0,0 @@
from cpl.core.console.console import Console
from cpl.core.utils.string import String
class ScopedService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -1,45 +0,0 @@
from cpl.application.abc import ApplicationABC
from cpl.core.console.console import Console
from cpl.dependency import ServiceProvider
from test_abc import TestABC
from test_service import TestService
from di_tester_service import DITesterService
from tester import Tester
class Application(ApplicationABC):
def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services)
def _part_of_scoped(self):
ts: TestService = self._services.get_service(TestService)
ts.run()
def main(self):
with self._services.create_scope() as scope:
Console.write_line("Scope1")
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
with self._services.create_scope() as scope:
Console.write_line("Scope2")
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
Console.write_line("Global")
self._part_of_scoped()
#from static_test import StaticTest
#StaticTest.test()
self._services.get_service(Tester)
Console.write_line(self._services.get_services(TestABC))

View File

@@ -1,27 +0,0 @@
from cpl.application.abc import StartupABC
from cpl.dependency import ServiceProvider, ServiceCollection
from di_tester_service import DITesterService
from test1_service import Test1Service
from test2_service import Test2Service
from test_abc import TestABC
from test_service import TestService
from tester import Tester
class Startup(StartupABC):
def __init__(self):
StartupABC.__init__(self)
@staticmethod
def configure_configuration(): ...
@staticmethod
def configure_services(services: ServiceCollection) -> ServiceProvider:
services.add_scoped(TestService)
services.add_scoped(DITesterService)
services.add_singleton(TestABC, Test1Service)
services.add_singleton(TestABC, Test2Service)
services.add_singleton(Tester)
return services.build()

View File

@@ -1,10 +0,0 @@
from cpl.dependency import ServiceProvider, ServiceProvider
from cpl.dependency.inject import inject
from test_service import TestService
class StaticTest:
@staticmethod
@inject
def test(services: ServiceProvider, t1: TestService):
t1.run()

View File

@@ -1,7 +0,0 @@
from cpl.core.console.console import Console
from test_abc import TestABC
class Tester:
def __init__(self, t1: TestABC, t2: TestABC, t3: TestABC, t: list[TestABC]):
Console.write_line("Tester:", t, t1, t2, t3)

View File

@@ -1,20 +0,0 @@
import asyncio
import time
from cpl.core.console import Console
from cpl.dependency.hosted.hosted_service import HostedService
class Hosted(HostedService):
def __init__(self):
self._stopped = False
async def start(self):
Console.write_line("Hosted Service Started")
while not self._stopped:
Console.write_line("Hosted Service Running")
await asyncio.sleep(5)
async def stop(self):
Console.write_line("Hosted Service Stopped")
self._stopped = True

View File

@@ -1,10 +0,0 @@
from cpl.core.console import Console
class ScopedService:
def __init__(self):
self.value = "I am a scoped service"
Console.write_line(self.value, self)
def get_value(self):
return self.value

View File

@@ -1,60 +0,0 @@
from cpl.core.console import Console
from cpl.core.utils.benchmark import Benchmark
from cpl.query.enumerable import Enumerable
from cpl.query.immutable_list import ImmutableList
from cpl.query.list import List
from cpl.query.set import Set
def _default():
Console.write_line(Enumerable.empty().to_list())
Console.write_line(Enumerable.range(0, 100).length)
Console.write_line(Enumerable.range(0, 100).to_list())
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
Console.write_line(
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
)
Console.write_line(List)
s =Enumerable.range(0, 10).to_set()
Console.write_line(s)
s.add(1)
Console.write_line(s)
data = Enumerable(
[
{"name": "Alice", "age": 30},
{"name": "Dave", "age": 35},
{"name": "Charlie", "age": 25},
{"name": "Bob", "age": 25},
]
)
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
def t_benchmark(data: list):
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all(
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
)
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
def main():
N = 1_000_000
data = list(range(N))
t_benchmark(data)
Console.write_line()
_default()
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,32 @@
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
from .logger import APILogger
from .settings import ApiSettings
from .api_module import ApiModule
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
def add_api(collection: _ServiceCollection):
try:
from cpl.database import mysql
collection.add_module(mysql)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-database", e)
try:
from cpl import auth
from cpl.auth import permission
collection.add_module(auth)
collection.add_module(permission)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)
_ServiceCollection.with_module(add_api, __name__)

View File

@@ -1 +0,0 @@
from .asgi_middleware_abc import ASGIMiddleware

View File

@@ -1,24 +0,0 @@
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_module import PermissionsModule
from cpl.database.database_module import DatabaseModule
from cpl.dependency.module import Module, TModule
class ApiModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return [AuthModule, DatabaseModule, PermissionsModule]
@staticmethod
def register(collection: "ServiceCollection"):
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_module(DatabaseModule)
collection.add_module(AuthModule)
collection.add_module(PermissionsModule)
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)

View File

@@ -1 +0,0 @@
from .web_app import WebApp

View File

@@ -25,46 +25,42 @@ from cpl.api.registry.route import RouteRegistry
from cpl.api.router import Router
from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.api.api_module import ApiModule
from cpl.application.abc.application_abc import ApplicationABC
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_module import PermissionsModule
from cpl.core.configuration import Configuration
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProvider):
super().__init__(services, [AuthModule, PermissionsModule, ApiModule])
def __init__(self, services: ServiceProviderABC):
super().__init__(services, [auth, api])
self._app: Starlette | None = None
self._logger = services.get_service(APILogger)
self._api_settings = Configuration.get(ApiSettings)
self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = []
self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception,
APIError: self._handle_exception,
}
self.with_middleware(RequestMiddleware)
self.with_middleware(LoggingMiddleware)
async def _handle_exception(self, request: Request, exc: Exception):
@staticmethod
async def _handle_exception(request: Request, exc: Exception):
if isinstance(exc, APIError):
self._logger.error(exc)
_logger.error(exc)
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
if hasattr(request.state, "request_id"):
self._logger.error(f"Request {request.state.request_id}", exc)
_logger.error(f"Request {request.state.request_id}", exc)
else:
self._logger.error("Request unknown", exc)
_logger.error("Request unknown", exc)
return JSONResponse({"error": str(exc)}, status_code=500)
@@ -72,10 +68,10 @@ class WebApp(ApplicationABC):
origins = self._api_settings.allowed_origins
if origins is None or origins == "":
self._logger.warning("No allowed origins specified, allowing all origins")
_logger.warning("No allowed origins specified, allowing all origins")
return ["*"]
self._logger.debug(f"Allowed origins: {origins}")
_logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_database(self) -> Self:
@@ -171,9 +167,9 @@ class WebApp(ApplicationABC):
self._check_for_app()
if isinstance(middleware, Middleware):
self._middleware.append(inject(middleware))
self._middleware.append(middleware)
elif callable(middleware):
self._middleware.append(Middleware(inject(middleware)))
self._middleware.append(Middleware(middleware))
else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
@@ -194,11 +190,11 @@ class WebApp(ApplicationABC):
if isinstance(policy, dict):
for name, resolver in policy.items():
if not isinstance(name, str):
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
_logger.warning(f"Skipping policy at index {i}, name must be a string")
continue
if not callable(resolver):
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
_logger.warning(f"Skipping policy {name}, resolver must be callable")
continue
_policies.append(Policy(name, resolver))
@@ -206,7 +202,7 @@ class WebApp(ApplicationABC):
_policies.append(policy)
self._policies.extend(_policies)
self._policies.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware)
return self
@@ -216,14 +212,14 @@ class WebApp(ApplicationABC):
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
_logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self):
self._logger.debug(f"Preparing API")
_logger.debug(f"Preparing API")
self._validate_policies()
if self._app is None:
routes = [route.to_starlette(inject) for route in self._routes.all()]
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
app = Starlette(
routes=routes,
@@ -241,7 +237,7 @@ class WebApp(ApplicationABC):
else:
app = self._app
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
@@ -249,4 +245,4 @@ class WebApp(ApplicationABC):
server = uvicorn.Server(config)
await server.serve()
self._logger.info("Shutdown API")
_logger.info("Shutdown API")

View File

@@ -1,7 +1,7 @@
from cpl.core.log.wrapped_logger import WrappedLogger
from cpl.core.log.logger import Logger
class APILogger(WrappedLogger):
class APILogger(Logger):
def __init__(self):
WrappedLogger.__init__(self, "api")
def __init__(self, source: str):
Logger.__init__(self, source, "api")

View File

@@ -1,4 +0,0 @@
from .authentication import AuthenticationMiddleware
from .authorization import AuthorizationMiddleware
from .logging import LoggingMiddleware
from .request import RequestMiddleware

View File

@@ -2,22 +2,24 @@ from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized
from cpl.api.logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
@ServiceProviderABC.inject
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._logger = logger
self._keycloak = keycloak
self._user_dao = user_dao
@@ -26,11 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
_logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)
if not request.headers.get("Authorization"):
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None)
@@ -39,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token):
self._logger.debug(f"Unauthorized access to {url}, invalid token")
_logger.debug(f"Unauthorized access to {url}, invalid token")
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create
@@ -49,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
user = await self._get_or_crate_user(keycloak_id)
if user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
_logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user
@@ -71,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
token_info = self._keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
self._logger.debug(f"Keycloak authentication error: {e}")
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
self._logger.error(f"Unexpected error during token verification: {e}")
_logger.error(f"Unexpected error during token verification: {e}")
return False

View File

@@ -9,15 +9,17 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger(__name__)
class AuthorizationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
@ServiceProviderABC.inject
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._logger = logger
self._policies = policies
self._user_dao = user_dao
@@ -26,7 +28,7 @@ class AuthorizationMiddleware(ASGIMiddleware):
url = request.url.path
if url not in Router.get_authorization_rules_paths():
self._logger.trace(f"No authorization required for {url}")
_logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send)
user = get_user()
@@ -51,21 +53,17 @@ class AuthorizationMiddleware(ASGIMiddleware):
if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.warning(f"Authorization policy '{policy_name}' not found")
_logger.warning(f"Authorization policy '{policy_name}' not found")
continue
if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
return await self._call_next(scope, receive, send)

View File

@@ -7,14 +7,14 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger):
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
self._logger = logger
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self._call_next(scope, receive, send)
@@ -53,8 +53,9 @@ class LoggingMiddleware(ASGIMiddleware):
}
return {key: value for key, value in headers.items() if key in relevant_keys}
async def _log_request(self, request: Request):
self._logger.debug(
@classmethod
async def _log_request(cls, request: Request):
_logger.debug(
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
)
@@ -63,7 +64,7 @@ class LoggingMiddleware(ASGIMiddleware):
user = get_user()
request_info = {
"headers": self._filter_relevant_headers(dict(request.headers)),
"headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params),
"form-data": (
await request.form()
@@ -77,9 +78,10 @@ class LoggingMiddleware(ASGIMiddleware):
),
}
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
async def _log_after_request(self, request: Request, status_code: int, duration: float):
self._logger.info(
@staticmethod
async def _log_after_request(request: Request, status_code: int, duration: float):
_logger.info(
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@@ -9,20 +9,16 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
_logger = APILogger(__name__)
class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send):
@@ -30,15 +26,14 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request)
try:
with self._provider.create_scope():
inject(await self._app(scope, receive, send))
await self._app(scope, receive, send)
finally:
await self.clean_request_data()
async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4()
request.state.start_time = time.time()
self._logger.trace(f"Set new current request: {request.state.request_id}")
_logger.trace(f"Set new current request: {request.state.request_id}")
self._ctx_token = _request_context.set(request)
@@ -50,7 +45,7 @@ class RequestMiddleware(ASGIMiddleware):
if self._ctx_token is None:
return
self._logger.trace(f"Clearing current request: {request.state.request_id}")
_logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token)

View File

@@ -1,3 +0,0 @@
from .api_route import ApiRoute
from .policy import Policy
from .validation_match import ValidationMatch

View File

@@ -7,7 +7,13 @@ from cpl.api.typing import HTTPMethods
class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
def __init__(
self,
path: str,
fn: Callable,
method: HTTPMethods,
**kwargs
):
self._path = path
self._fn = fn
self._method = method

View File

@@ -1,5 +1,5 @@
from asyncio import iscoroutinefunction
from typing import Optional
from typing import Optional, Any, Coroutine, Awaitable
from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user

View File

@@ -1,2 +0,0 @@
from .policy import PolicyRegistry
from .route import RouteRegistry

View File

@@ -1,5 +1,6 @@
from typing import Optional
from cpl.api.model.policy import Policy
from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC

View File

@@ -3,7 +3,6 @@ from enum import Enum
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods
from cpl.dependency import get_provider
class Router:
@@ -42,13 +41,7 @@ class Router:
return inner
@classmethod
def authorize(
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
"""
Decorator to mark a route as requiring authorization.
Usage:
@@ -92,14 +85,15 @@ class Router:
return inner
@classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.api_route import ApiRoute
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
if not registry:
routes = get_provider().get_service(RouteRegistry)
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else:
routes = registry
def inner(fn):
routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path)
@@ -143,9 +137,8 @@ class Router:
"""
from cpl.api.model.api_route import ApiRoute
routes = get_provider().get_service(RouteRegistry)
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn):
path = getattr(fn, "_route_path", None)
if path is None:
@@ -154,7 +147,7 @@ class Router:
route = routes.get(path)
if route is None:
raise ValueError(f"Cannot override a route that does not exist: {path}")
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
setattr(fn, "_route_path", path)
return fn

View File

@@ -16,4 +16,4 @@ PartialMiddleware = Union[
Middleware,
Callable[[ASGIApp], ASGIApp],
]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

@@ -1,2 +1 @@
from .application_builder import ApplicationBuilder
from .host import Host

View File

@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
from typing import Callable, Self
from cpl.application.host import Host
from cpl.core.console.console import Console
from cpl.core.log import LogSettings
from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
def __not_implemented__(package: str, func: Callable):
@@ -16,12 +17,12 @@ class ApplicationABC(ABC):
r"""ABC for the Application class
Parameters:
services: :class:`cpl.dependency.service_provider.ServiceProvider`
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
Contains instances of prepared objects
"""
@abstractmethod
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
self._services = services
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
@@ -84,7 +85,7 @@ class ApplicationABC(ABC):
Called by custom Application.main
"""
try:
Host.run_app(self.main)
Host.run(self.main)
except KeyboardInterrupt:
pass

View File

@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency import ServiceProviderABC
class ApplicationExtensionABC(ABC):
@staticmethod
@abstractmethod
def run(services: ServiceProvider): ...
def run(services: ServiceProviderABC): ...

View File

@@ -7,7 +7,6 @@ from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host
from cpl.core.errors import dependency_error
from cpl.dependency.context import get_provider, use_root_provider
from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -22,7 +21,6 @@ class ApplicationBuilder(Generic[TApp]):
self._app = app if app is not None else ApplicationABC
self._services = ServiceCollection()
use_root_provider(self._services.build())
self._startup: Optional[StartupABC] = None
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
@@ -36,12 +34,7 @@ class ApplicationBuilder(Generic[TApp]):
@property
def service_provider(self):
provider = get_provider()
if provider is None:
provider = self._services.build()
use_root_provider(provider)
return provider
return self._services.build()
def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules:
@@ -49,7 +42,6 @@ class ApplicationBuilder(Generic[TApp]):
continue
dependency_error(
type(app).__name__,
module,
ImportError(
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
@@ -83,7 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
for extension in self._app_extensions:
Host.run(extension.run, self.service_provider)
use_root_provider(self._services.build())
app = self._app(self.service_provider)
self.validate_app_required_modules(app)
return app

View File

@@ -1,80 +1,17 @@
import asyncio
from typing import Callable
from cpl.dependency import get_provider
from cpl.dependency.hosted.startup_task import StartupTask
class Host:
_loop: asyncio.AbstractEventLoop | None = None
_tasks: dict = {}
_loop = asyncio.get_event_loop()
@classmethod
def get_loop(cls) -> asyncio.AbstractEventLoop:
if cls._loop is None:
cls._loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls._loop)
def get_loop(cls):
return cls._loop
@classmethod
def run_start_tasks(cls):
provider = get_provider()
tasks = provider.get_services(StartupTask)
loop = cls.get_loop()
for task in tasks:
if asyncio.iscoroutinefunction(task.run):
loop.run_until_complete(task.run())
else:
task.run()
@classmethod
def run_hosted_services(cls):
provider = get_provider()
services = provider.get_hosted_services()
loop = cls.get_loop()
for service in services:
if asyncio.iscoroutinefunction(service.start):
cls._tasks[service] = loop.create_task(service.start())
@classmethod
async def _stop_all(cls):
for service in cls._tasks.keys():
if asyncio.iscoroutinefunction(service.stop):
await service.stop()
for task in cls._tasks.values():
task.cancel()
cls._tasks.clear()
@classmethod
def run_app(cls, func: Callable, *args, **kwargs):
cls.run_start_tasks()
cls.run_hosted_services()
async def runner():
try:
if asyncio.iscoroutinefunction(func):
app_task = asyncio.create_task(func(*args, **kwargs))
else:
app_task = cls.get_loop().run_in_executor(None, func, *args, **kwargs)
await asyncio.wait(
[app_task, *cls._tasks.values()],
return_when=asyncio.FIRST_COMPLETED,
)
except (KeyboardInterrupt, asyncio.CancelledError):
pass
finally:
await cls._stop_all()
cls.get_loop().run_until_complete(runner())
@classmethod
def run(cls, func: Callable, *args, **kwargs):
if asyncio.iscoroutinefunction(func):
return cls.get_loop().run_until_complete(func(*args, **kwargs))
return cls._loop.run_until_complete(func(*args, **kwargs))
return func(*args, **kwargs)

View File

@@ -5,9 +5,10 @@ from cpl.application.abc import ApplicationABC as _ApplicationABC
from cpl.auth import permission as _permission
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .auth_logger import AuthLogger
from .keycloak_settings import KeycloakSettings
from .logger import AuthLogger
from .auth_module import AuthModule
from .permission_seeder import PermissionSeeder
def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _ApplicationABC:
@@ -18,4 +19,66 @@ def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _Appli
return self
def _add_daos(collection: _ServiceCollection):
from .schema._administration.auth_user_dao import AuthUserDao
from .schema._administration.api_key_dao import ApiKeyDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao
from .schema._permission.role_permission_dao import RolePermissionDao
from .schema._permission.role_user_dao import RoleUserDao
collection.add_singleton(AuthUserDao)
collection.add_singleton(ApiKeyDao)
collection.add_singleton(ApiKeyPermissionDao)
collection.add_singleton(PermissionDao)
collection.add_singleton(RoleDao)
collection.add_singleton(RolePermissionDao)
collection.add_singleton(RoleUserDao)
def add_auth(collection: _ServiceCollection):
import os
try:
from cpl.database.service.migration_service import MigrationService
from cpl.database.model.server_type import ServerType, ServerTypes
collection.add_singleton(_KeycloakClient)
collection.add_singleton(_KeycloakAdmin)
_add_daos(collection)
provider = collection.build()
migration_service: MigrationService = provider.get_service(MigrationService)
if ServerType.server_type == ServerTypes.POSTGRES:
migration_service.with_directory(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
)
elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
def add_permission(collection: _ServiceCollection):
from .permission_seeder import PermissionSeeder
from .permission.permissions_registry import PermissionsRegistry
from .permission.permissions import Permissions
try:
from cpl.database.abc.data_seeder_abc import DataSeederABC
collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions)
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
_ServiceCollection.with_module(add_auth, __name__)
_ServiceCollection.with_module(add_permission, _permission.__name__)
_ApplicationABC.extend(_ApplicationABC.with_permissions, _with_permissions)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class AuthLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "auth")

View File

@@ -1,45 +0,0 @@
import os
from cpl.database.database_module import DatabaseModule
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.service.migration_service import MigrationService
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class AuthModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return [DatabaseModule]
@staticmethod
def register(collection: ServiceCollection):
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient
from .schema._administration.api_key_dao import ApiKeyDao
from .schema._administration.auth_user_dao import AuthUserDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao
from .schema._permission.role_permission_dao import RolePermissionDao
from .schema._permission.role_user_dao import RoleUserDao
collection.add_singleton(KeycloakClient)
collection.add_singleton(KeycloakAdmin)
collection.add_singleton(AuthUserDao)
collection.add_singleton(ApiKeyDao)
collection.add_singleton(ApiKeyPermissionDao)
collection.add_singleton(PermissionDao)
collection.add_singleton(RoleDao)
collection.add_singleton(RolePermissionDao)
collection.add_singleton(RoleUserDao)
provider = collection.build()
migration_service: MigrationService = provider.get_service(MigrationService)
if ServerType.server_type == ServerTypes.POSTGRES:
migration_service.with_directory(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
)
elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))

View File

@@ -1,13 +1,15 @@
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings
from cpl.auth.logger import AuthLogger
_logger = AuthLogger("keycloak")
class KeycloakAdmin(_KeycloakAdmin):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
# logger.info("Initializing Keycloak admin")
def __init__(self, settings: KeycloakSettings):
_logger.info("Initializing Keycloak admin")
_connection = KeycloakOpenIDConnection(
server_url=settings.url,
client_id=settings.client_id,

View File

@@ -2,13 +2,15 @@ from typing import Optional
from keycloak import KeycloakOpenID
from cpl.auth.logger import AuthLogger
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings
_logger = AuthLogger("keycloak")
class KeycloakClient(KeycloakOpenID):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
def __init__(self, settings: KeycloakSettings):
KeycloakOpenID.__init__(
self,
server_url=settings.url,
@@ -16,7 +18,7 @@ class KeycloakClient(KeycloakOpenID):
realm_name=settings.realm,
client_secret_key=settings.client_secret,
)
logger.info("Initializing Keycloak client")
_logger.info("Initializing Keycloak client")
def get_user_id(self, token: str) -> Optional[str]:
info = self.introspect(token)

View File

@@ -1,5 +1,5 @@
from cpl.core.utils.get_value import get_value
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class KeycloakUser:
@@ -32,5 +32,5 @@ class KeycloakUser:
def id(self) -> str:
from cpl.auth import KeycloakAdmin
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user_id(self._username)

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class AuthLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "auth")

View File

@@ -1,4 +0,0 @@
from .permission_module import PermissionsModule
from .permission_seeder import PermissionSeeder
from .permissions import Permissions
from .permissions_registry import PermissionsRegistry

View File

@@ -1,20 +0,0 @@
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_seeder import PermissionSeeder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class PermissionsModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
from cpl.database.database_module import DatabaseModule
return [DatabaseModule, AuthModule]
@staticmethod
def register(collection: ServiceCollection):
collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions)

View File

@@ -14,13 +14,14 @@ from cpl.auth.schema import (
)
from cpl.core.utils.get_value import get_value
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionSeeder(DataSeederABC):
def __init__(
self,
logger: DBLogger,
permission_dao: PermissionDao,
role_dao: RoleDao,
role_permission_dao: RolePermissionDao,
@@ -28,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
api_key_permission_dao: ApiKeyPermissionDao,
):
DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao
self._role_dao = role_dao
self._role_permission_dao = role_permission_dao
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
possible_permissions = [permission for permission in PermissionsRegistry.get()]
if len(permissions) == len(possible_permissions):
self._logger.info("Permissions already existing")
_logger.info("Permissions already existing")
await self._update_missing_descriptions()
return
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
await self._permission_dao.delete_many(to_delete, hard_delete=True)
self._logger.warning("Permissions incomplete")
_logger.warning("Permissions incomplete")
permission_names = [permission.name for permission in permissions]
await self._permission_dao.create_many(
[

View File

@@ -10,8 +10,7 @@ from cpl.core.log.logger import Logger
from cpl.core.typing import Id, SerialId
from cpl.core.utils.credential_manager import CredentialManager
from cpl.database.abc.db_model_abc import DbModelABC
from cpl.dependency import get_provider
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = Logger(__name__)
@@ -48,7 +47,7 @@ class ApiKey(DbModelABC):
async def permissions(self):
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._administration.api_key import ApiKey
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyDao(DbModelDaoABC[ApiKey]):
def __init__(self):
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys"))
DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
self.attribute(ApiKey.identifier, str)
self.attribute(ApiKey.key, str, "keystring")

View File

@@ -6,11 +6,13 @@ from async_property import async_property
from keycloak import KeycloakGetError
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
_logger = AuthLogger(__name__)
class AuthUser(DbModelABC):
@@ -36,13 +38,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS"
try:
keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username")
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e:
return "UNKNOWN"
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN"
@property
@@ -51,39 +52,38 @@ class AuthUser(DbModelABC):
return "ANONYMOUS"
try:
keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email")
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e:
return "UNKNOWN"
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN"
@async_property
async def roles(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
@async_property
async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.has_permission(self.id, permission)
async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self)

View File

@@ -4,14 +4,17 @@ from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users"))
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
@@ -36,7 +39,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map(
f"""

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class ApiKeyPermission(DbJoinModelABC):
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
async def api_key(self):
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
return await api_key_dao.get_by_id(self._api_key_id)
@property
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
def __init__(self):
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions"))
DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
self.attribute(ApiKeyPermission.api_key_id, int)
self.attribute(ApiKeyPermission.permission_id, int)

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._permission.permission import Permission
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionDao(DbModelDaoABC[Permission]):
def __init__(self):
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions"))
DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
self.attribute(Permission.name, str)
self.attribute(Permission.description, Optional[str])

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class Role(DbModelABC):
@@ -44,22 +44,22 @@ class Role(DbModelABC):
async def permissions(self):
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
@async_property
async def users(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
p = await permission_dao.get_by_name(permission.value)

View File

@@ -1,11 +1,14 @@
from cpl.auth.schema._permission.role import Role
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleDao(DbModelDaoABC[Role]):
def __init__(self):
DbModelDaoABC.__init__(self, Role, TableManager.get("roles"))
DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
self.attribute(Role.name, str)
self.attribute(Role.description, str)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class RolePermission(DbModelABC):
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao)
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id)
@property
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_permission import RolePermission
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RolePermissionDao(DbModelDaoABC[RolePermission]):
def __init__(self):
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions"))
DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
self.attribute(RolePermission.role_id, int)
self.attribute(RolePermission.permission_id, int)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class RoleUser(DbJoinModelABC):
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._user_id)
@property
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao)
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_user import RoleUser
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleUserDao(DbModelDaoABC[RoleUser]):
def __init__(self):
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users"))
DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
self.attribute(RoleUser.role_id, int)
self.attribute(RoleUser.user_id, int)

View File

@@ -1,18 +1,17 @@
from contextvars import ContextVar
from typing import Optional
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
_logger = AuthLogger(__name__)
def set_user(user: Optional[AuthUser]):
from cpl.core.log.logger_abc import LoggerABC
logger = get_provider().get_service(LoggerABC)
logger.trace("Setting user context", user.id)
_user_context.set(user)
def set_user(user_id: Optional[AuthUser]):
_logger.trace("Setting user context", user_id)
_user_context.set(user_id)
def get_user() -> Optional[AuthUser]:

View File

@@ -3,20 +3,8 @@ import traceback
from cpl.core.console import Console
def dependency_error(src: str, package_name: str, e: ImportError = None) -> None:
Console.error(f"'{package_name}' is required to use feature: {src}. Please install it and try again.")
tb = traceback.format_exc()
if not tb.startswith("NoneType: None"):
Console.write_line("->", tb)
elif e is not None:
Console.write_line("->", str(e))
exit(1)
def module_dependency_error(src: str, module: str, e: ImportError = None) -> None:
Console.error(f"'{module}' is required to use feature: {src}. Please initialize it with `add_module({module})`.")
def dependency_error(package_name: str, e: ImportError) -> None:
Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
tb = traceback.format_exc()
if not tb.startswith("NoneType: None"):
Console.write_line("->", tb)

View File

@@ -2,4 +2,3 @@ from .logger import Logger
from .logger_abc import LoggerABC
from .log_level import LogLevel
from .log_settings import LogSettings
from .structured_logger import StructuredLogger

View File

@@ -1,111 +0,0 @@
import asyncio
import importlib.util
import json
import traceback
from datetime import datetime
from starlette.requests import Request
from cpl.core.log.log_level import LogLevel
from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages
from cpl.dependency import get_provider
class StructuredLogger(Logger):
def __init__(self, source: Source, file_prefix: str = None):
Logger.__init__(self, source, file_prefix)
@property
def log_file(self):
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
def _log(self, level: LogLevel, *messages: Messages):
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = self._format_message(level.value, timestamp, *messages)
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
self._write_log_to_file(level, structured_message)
self._write_to_console(level, formatted_message)
except Exception as e:
print(f"Error while logging: {e} -> {traceback.format_exc()}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self._source,
"messages": messages,
}
self._enrich_message_with_request(structured_message)
self._enrich_message_with_user(structured_message)
return json.dumps(structured_message, ensure_ascii=False)
@staticmethod
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
scope = dict(request.scope)
def convert(value):
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, (list, tuple)):
return [convert(v) for v in value]
if isinstance(value, dict):
return {str(k): convert(v) for k, v in value.items()}
if not isinstance(value, (str, int, float, bool, type(None))):
return str(value)
return value
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
if not include_headers and "headers" in serializable_scope:
serializable_scope["headers"] = "<omitted>"
return serializable_scope
def _enrich_message_with_request(self, message: dict):
if importlib.util.find_spec("cpl.api") is None:
return
from cpl.api.middleware.request import get_request
from starlette.requests import Request
request = get_request()
if request is None:
return
message["request"] = {
"url": str(request.url),
"method": request.method,
"scope": self._scope_to_json(request),
}
if isinstance(request, Request) and request.scope == "http":
request: Request = request # fix typing for IDEs
message["request"]["data"] = asyncio.create_task(request.body())
@staticmethod
def _enrich_message_with_user(message: dict):
if importlib.util.find_spec("cpl-auth") is None:
return
from cpl.core.ctx import get_user
user = get_user()
if user is None:
return
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = get_provider().get_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = {
"id": str(user.id),
"username": kc_user.get("username"),
"email": kc_user.get("email"),
}

View File

@@ -1,101 +0,0 @@
import inspect
from typing import Type
from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
class WrappedLogger(LoggerABC):
def __init__(self, file_prefix: str):
LoggerABC.__init__(self)
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
self._source = None
self._file_prefix = file_prefix
self._set_logger()
@inject
def _set_logger(self, services: ServiceProvider):
from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProvider")
self._logger = t_logger(self._source, self._file_prefix)
def set_level(self, level: LogLevel):
self._logger.set_level(level)
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
return self._logger._format_message(level, timestamp, *messages)
@staticmethod
def _get_source() -> str | None:
stack = inspect.stack()
if len(stack) <= 1:
return None
from cpl.dependency import ServiceCollection
ignore_classes = [
ServiceProvider,
ServiceProvider.__subclasses__(),
ServiceCollection,
WrappedLogger,
WrappedLogger.__subclasses__(),
]
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
for i, frame_info in enumerate(stack[1:]):
module = inspect.getmodule(frame_info.frame)
if module is None:
continue
if module.__name__ in ignore_classes or module in ignore_classes:
continue
if module in ignore_modules or module.__name__ in ignore_modules:
continue
if module.__name__ != __name__:
return module.__name__
return None
def _set_source(self):
self._source = self._get_source()
self._set_logger()
def header(self, string: str):
self._set_source()
self._logger.header(string)
def trace(self, *messages: Messages):
self._set_source()
self._logger.trace(*messages)
def debug(self, *messages: Messages):
self._set_source()
self._logger.debug(*messages)
def info(self, *messages: Messages):
self._set_source()
self._logger.info(*messages)
def warning(self, *messages: Messages):
self._set_source()
self._logger.warning(*messages)
def error(self, messages: str, e: Exception = None):
self._set_source()
self._logger.error(messages, e)
def fatal(self, messages: str, e: Exception = None):
self._set_source()
self._logger.fatal(messages, e)

View File

@@ -14,4 +14,3 @@ UuidId = str | UUID
SerialId = int
Id = UuidId | SerialId
TNumber = int | float | complex

View File

@@ -1,57 +0,0 @@
import time
import tracemalloc
from typing import List, Callable
from cpl.core.console import Console
class Benchmark:
@staticmethod
def all(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
mems: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_time = sum(times) / len(times)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
@staticmethod
def time(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
avg_time = sum(times) / len(times)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
@staticmethod
def memory(label: str, func: Callable, iterations: int = 5):
mems: List[float] = []
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")

View File

@@ -1,100 +0,0 @@
import threading
import time
from typing import Generic
from cpl.core.typing import T
class Cache(Generic[T]):
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
self._store = {}
self._default_ttl = default_ttl
self._lock = threading.Lock()
self._cleanup_interval = cleanup_interval
self._stop_event = threading.Event()
self._type = t
# Start background cleanup thread
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
self._thread.start()
def set(self, key: str, value: T, ttl: int = None) -> None:
"""Store a value in the cache with optional TTL override."""
expire_at = None
ttl = ttl if ttl is not None else self._default_ttl
if ttl is not None:
expire_at = time.time() + ttl
with self._lock:
self._store[key] = (value, expire_at)
def get(self, key: str) -> T | None:
"""Retrieve a value from the cache if not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return None
value, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return None
del self._store[key]
return None
return value
def get_all(self) -> list[T]:
"""Retrieve all non-expired values from the cache."""
now = time.time()
with self._lock:
valid_items = []
expired_keys = []
for k, (v, exp) in self._store.items():
if exp and exp < now:
expired_keys.append(k)
else:
valid_items.append(v)
for k in expired_keys:
del self._store[k]
return valid_items
def has(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return False
_, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return False
del self._store[key]
return False
return True
def delete(self, key: str) -> None:
"""Remove an item from the cache."""
with self._lock:
self._store.pop(key, None)
def clear(self) -> None:
"""Clear the entire cache."""
with self._lock:
self._store.clear()
def _auto_cleanup(self):
"""Background thread to clean expired items."""
while not self._stop_event.is_set():
self.cleanup()
self._stop_event.wait(self._cleanup_interval)
def cleanup(self) -> None:
"""Remove expired items immediately."""
now = time.time()
with self._lock:
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
for k in expired_keys:
del self._store[k]
def stop(self):
"""Stop the background cleanup thread."""
self._stop_event.set()
self._thread.join()

View File

@@ -1,48 +0,0 @@
from typing import Any
class Number:
@staticmethod
def is_number(value: Any) -> bool:
"""Check if the value is a number (int or float)."""
return isinstance(value, (int, float, complex))
@staticmethod
def to_number(value: Any) -> int | float | complex:
"""
Convert a given value into int, float, or complex.
Raises ValueError if conversion is not possible.
"""
if isinstance(value, (int, float, complex)):
return value
if isinstance(value, str):
value = value.strip()
for caster in (int, float, complex):
try:
return caster(value)
except ValueError:
continue
raise ValueError(f"Cannot convert string '{value}' to number.")
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
pass
try:
return float(value)
except Exception:
pass
try:
return complex(value)
except Exception:
pass
raise ValueError(f"Cannot convert type {type(value)} to number.")

View File

@@ -114,15 +114,12 @@ class String:
characters = []
if letters:
characters.extend(string.ascii_letters)
characters.append(string.ascii_letters)
if digits:
characters.extend(string.digits)
characters.append(string.digits)
if special_characters:
characters.extend(string.punctuation)
characters.append(string.punctuation)
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else ""
if len(x) != length:
raise Exception("No characters selected to generate random string")
return x
return "".join(random.choice(characters) for _ in range(length)) if characters else ""

View File

@@ -1,14 +1,15 @@
import os
from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC
from cpl.dependency import ServiceCollection as _ServiceCollection
from . import mysql as _mysql
from . import postgres as _postgres
from .database_module import DatabaseModule
from .table_manager import TableManager
from .logger import DBLogger
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService)
@@ -20,6 +21,8 @@ def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _Applica
for path in paths:
migration_service.with_directory(path)
Host.run(migration_service.migrate)
return self
@@ -32,5 +35,43 @@ def _with_seeders(self: _ApplicationABC) -> _ApplicationABC:
return self
def _add(collection: _ServiceCollection, db_context: Type, default_port: int, server_type: str):
from cpl.core.console import Console
from cpl.core.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.service.migration_service import MigrationService
from cpl.database.service.seeder_service import SeederService
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
try:
ServerType.set_server_type(ServerTypes(server_type))
Configuration.set("DB_DEFAULT_PORT", default_port)
collection.add_singleton(DBContextABC, db_context)
collection.add_singleton(ExecutedMigrationDao)
collection.add_singleton(MigrationService)
collection.add_singleton(SeederService)
except ImportError as e:
Console.error("cpl-database is not installed", str(e))
def add_mysql(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 3306, ServerTypes.MYSQL.value)
def add_postgres(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 5432, ServerTypes.POSTGRES.value)
_ServiceCollection.with_module(add_mysql, _mysql.__name__)
_ServiceCollection.with_module(add_postgres, _postgres.__name__)
_ApplicationABC.extend(_ApplicationABC.with_migrations, _with_migrations)
_ApplicationABC.extend(_ApplicationABC.with_seeders, _with_seeders)

View File

@@ -1,6 +1,4 @@
from .connection_abc import ConnectionABC
from .data_access_object_abc import DataAccessObjectABC
from .data_seeder_abc import DataSeederABC
from .db_context_abc import DBContextABC
from .db_join_model_abc import DbJoinModelABC
from .db_model_abc import DbModelABC

View File

@@ -9,20 +9,25 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency.context import get_provider
class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str):
self._db = get_provider().get_service(DBContextABC)
self._logger = get_provider().get_service(DBLogger)
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
from cpl.dependency.service_provider_abc import ServiceProviderABC
self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
@@ -352,13 +357,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}"
return f"""
INSERT INTO {self._table_name} (
{fields}
) VALUES (
{values}
)
{"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"}
"""
INSERT INTO {self._table_name} (
{fields}
) VALUES (
{values}
)
RETURNING {self.__primary_key};
"""
async def create(self, obj: T_DBM, skip_editor=False) -> int:
self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}")

View File

@@ -2,7 +2,7 @@ from abc import abstractmethod
from datetime import datetime
from typing import Type
from cpl.database.table_manager import TableManager
from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.abc.db_model_abc import DbModelABC
@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
@abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, model_type, table_name)
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, source, model_type, table_name)
self.attribute(DbModelABC.id, int, ignore=True)
self.attribute(DbModelABC.deleted, bool)

View File

@@ -1,23 +0,0 @@
from cpl.core.errors import module_dependency_error
from cpl.database.model.server_type import ServerType
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class DatabaseModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
if not ServerType.has_server_type:
module_dependency_error(__name__, "MySQLModule or PostgresModule")
return []
@staticmethod
def register(collection: ServiceCollection):
from cpl.database.schema import ExecutedMigrationDao
from cpl.database.service.migration_service import MigrationService
from cpl.database.service.seeder_service import SeederService
collection.add_singleton(ExecutedMigrationDao)
collection.add_singleton(MigrationService)
collection.add_singleton(SeederService)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class DBLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "db")

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class DBLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "db")

View File

@@ -15,11 +15,6 @@ class ServerType:
assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}"
cls._server_type = server_type
@classmethod
@property
def has_server_type(cls) -> bool:
return cls._server_type is not None
@classmethod
@property
def server_type(cls) -> ServerTypes:

View File

@@ -1,4 +0,0 @@
from .connection import DatabaseConnection
from .db_context import DBContext
from .mysql_module import MySQLModule
from .mysql_pool import MySQLPool

View File

@@ -4,17 +4,18 @@ from typing import Any, List, Dict, Tuple, Union
from mysql.connector import Error as MySQLError, PoolError
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_pool import MySQLPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC):
def __init__(self, logger: DBLogger):
def __init__(self):
DBContextABC.__init__(self)
self._logger = logger
self._pool: MySQLPool = None
self._fails = 0
@@ -22,62 +23,62 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings):
try:
self._logger.debug("Connecting to database")
_logger.debug("Connecting to database")
self._pool = MySQLPool(
database_settings,
)
self._logger.info("Connected to database")
_logger.info("Connected to database")
except Exception as e:
self._logger.fatal("Connecting to database failed", e)
_logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
self._logger.trace(f"execute {statement} with args: {args}")
_logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> List[Dict]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select_map(statement, args)
except (MySQLError, PoolError) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select_map(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select(statement, args)
except (MySQLError, PoolError) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e

View File

@@ -1,20 +0,0 @@
from cpl.core.configuration.configuration import Configuration
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class MySQLModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return []
@staticmethod
def register(collection: ServiceCollection):
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.mysql.db_context import DBContext
ServerType.set_server_type(ServerTypes(ServerTypes.MYSQL.value))
Configuration.set("DB_DEFAULT_PORT", 3306)
collection.add_singleton(DBContextABC, DBContext)

View File

@@ -4,9 +4,10 @@ import sqlparse
from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.dependency.context import get_provider
_logger = DBLogger(__name__)
class MySQLPool:
@@ -18,11 +19,7 @@ class MySQLPool:
"user": database_settings.user,
"password": database_settings.password,
"database": database_settings.database,
"charset": database_settings.charset,
"use_unicode": database_settings.use_unicode,
"buffered": database_settings.buffered,
"auth_plugin": database_settings.auth_plugin,
"ssl_disabled": False,
"ssl_disabled": True,
}
self._pool: Optional[MySQLConnectionPool] = None
@@ -39,8 +36,7 @@ class MySQLPool:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.fatal(f"Error connecting to the database: {e}")
_logger.fatal(f"Error connecting to the database: {e}")
finally:
await con.close()

View File

@@ -1,4 +0,0 @@
from .db_context import DBContext
from .postgres_module import PostgresModule
from .postgres_pool import PostgresPool
from .sql_select_builder import SQLSelectBuilder

View File

@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.database.database_settings import DatabaseSettings
from cpl.database.db_logger import DBLogger
from cpl.database.postgres.postgres_pool import PostgresPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC):
def __init__(self, logger: DBLogger):
def __init__(self):
DBContextABC.__init__(self)
self._logger = logger
self._pool: PostgresPool = None
self._fails = 0
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings):
try:
self._logger.debug("Connecting to database")
_logger.debug("Connecting to database")
self._pool = PostgresPool(
database_settings,
Environment.get("DB_POOL_SIZE", int, 1),
)
self._logger.info("Connected to database")
_logger.info("Connected to database")
except Exception as e:
self._logger.fatal("Connecting to database failed", e)
_logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
self._logger.trace(f"execute {statement} with args: {args}")
_logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> list[dict]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select_map(statement, args)
except (OperationalError, PoolTimeout) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select_map(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select(statement, args)
except (OperationalError, PoolTimeout) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e

View File

@@ -1,21 +0,0 @@
from cpl.core.configuration.configuration import Configuration
from cpl.database.database_module import DatabaseModule
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
class PostgresModule(Module):
@staticmethod
def dependencies() -> list[TModule]:
return [DatabaseModule]
@staticmethod
def register(collection: ServiceCollection):
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.postgres.db_context import DBContext
ServerType.set_server_type(ServerTypes(ServerTypes.POSTGRES.value))
Configuration.set("DB_DEFAULT_PORT", 5432)
collection.add_singleton(DBContextABC, DBContext)

View File

@@ -5,9 +5,10 @@ from psycopg import sql
from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProvider
_logger = DBLogger(__name__)
class PostgresPool:
@@ -37,8 +38,7 @@ class PostgresPool:
await pool.check_connection(con)
except PoolTimeout as e:
await pool.close()
logger = get_provider().get_service(DBLogger)
logger.fatal(f"Failed to connect to the database", e)
_logger.fatal(f"Failed to connect to the database", e)
self._pool = pool
return self._pool

View File

@@ -1,2 +0,0 @@
from .executed_migration import ExecutedMigration
from .executed_migration_dao import ExecutedMigrationDao

View File

@@ -1,11 +1,14 @@
from cpl.database.table_manager import TableManager
from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.db_logger import DBLogger
from cpl.database.schema.executed_migration import ExecutedMigration
_logger = DBLogger(__name__)
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
def __init__(self):
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations"))
DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")

View File

@@ -1,22 +1,21 @@
import glob
import os
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger
from cpl.database.model.migration import Migration
from cpl.database.abc import DBContextABC
from cpl.database.db_logger import DBLogger
from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
from cpl.dependency.hosted.startup_task import StartupTask
_logger = DBLogger(__name__)
class MigrationService(StartupTask):
class MigrationService:
def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao):
StartupTask.__init__(self)
self._logger = logger
def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
self._db = db
self._executed_migration_dao = executed_migration_dao
self._executedMigrationDao = executedMigrationDao
self._script_directories: list[str] = []
@@ -25,15 +24,12 @@ class MigrationService(StartupTask):
elif ServerType.server_type == ServerTypes.MYSQL:
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
async def run(self):
await self._execute(self._load_scripts())
def with_directory(self, directory: str) -> "MigrationService":
self._script_directories.append(directory)
return self
async def _get_migration_history(self) -> list[ExecutedMigration]:
results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}")
results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}")
applied_migrations = []
for result in results:
applied_migrations.append(ExecutedMigration(result[0]))
@@ -96,17 +92,20 @@ class MigrationService(StartupTask):
try:
# check if table exists
if len(result) > 0:
migration_from_db = await self._executed_migration_dao.find_by_id(migration.name)
migration_from_db = await self._executedMigrationDao.find_by_id(migration.name)
if migration_from_db is not None:
continue
self._logger.debug(f"Running upgrade migration: {migration.name}")
_logger.debug(f"Running upgrade migration: {migration.name}")
await self._db.execute(migration.script, multi=True)
await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True)
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e:
self._logger.fatal(
_logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}",
e,
)
async def migrate(self):
await self._execute(self._load_scripts())

View File

@@ -1,16 +1,18 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProvider
from cpl.database.db_logger import DBLogger
from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class SeederService:
def __init__(self, provider: ServiceProvider):
def __init__(self, provider: ServiceProviderABC):
self._provider = provider
self._logger = provider.get_service(DBLogger)
async def seed(self):
seeders = self._provider.get_services(DataSeederABC)
self._logger.debug(f"Found {len(seeders)} seeders")
_logger.debug(f"Found {len(seeders)} seeders")
for seeder in seeders:
await seeder.seed()

View File

@@ -1,7 +1,7 @@
from .context import get_provider, use_provider
from .inject import inject
from .scope import Scope
from .scope_abc import ScopeABC
from .service_collection import ServiceCollection
from .service_descriptor import ServiceDescriptor
from .service_lifetime import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_lifetime_enum import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_provider_abc import ServiceProviderABC

View File

@@ -1,22 +0,0 @@
import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider: "ServiceProvider"):
_current_provider.set(provider)
@contextmanager
def use_provider(provider: "ServiceProvider"):
token = _current_provider.set(provider)
try:
yield
finally:
_current_provider.reset(token)
def get_provider() -> "ServiceProvider":
return _current_provider.get()

View File

@@ -1,2 +0,0 @@
from .hosted_service import HostedService
from .startup_task import StartupTask

View File

@@ -1,9 +0,0 @@
from abc import ABC, abstractmethod
class HostedService(ABC):
@abstractmethod
async def start(self): ...
@abstractmethod
async def stop(self): ...

View File

@@ -1,6 +0,0 @@
from abc import ABC, abstractmethod
class StartupTask(ABC):
@abstractmethod
async def run(self): ...

View File

@@ -1,42 +0,0 @@
import functools
from asyncio import iscoroutinefunction
from inspect import signature
from cpl.dependency.context import get_provider
def inject(f=None):
if f is None:
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -1,15 +0,0 @@
from abc import abstractmethod, ABC
from typing import Type
TModule = Type["Module"]
class Module(ABC):
@staticmethod
@abstractmethod
def dependencies() -> list[TModule]: ...
@staticmethod
@abstractmethod
def register(collection: "ServiceCollection"): ...

View File

@@ -0,0 +1,22 @@
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class Scope(ScopeABC):
def __init__(self, service_provider: ServiceProviderABC):
self._service_provider = service_provider
self._service_provider.set_scope(self)
ScopeABC.__init__(self)
def __enter__(self):
return self
def __exit__(self, *args):
self.dispose()
@property
def service_provider(self) -> ServiceProviderABC:
return self._service_provider
def dispose(self):
self._service_provider = None

View File

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
class ScopeABC(ABC):
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
def __init__(self): ...
@property
@abstractmethod
def service_provider(self):
r"""Returns to service provider of scope
Returns:
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
"""
@abstractmethod
def dispose(self):
r"""Sets service_provider to None"""

View File

@@ -0,0 +1,18 @@
from cpl.dependency.scope import Scope
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ScopeBuilder:
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
def __init__(self, service_provider: ServiceProviderABC) -> None:
self._service_provider = service_provider
def build(self) -> ScopeABC:
r"""Returns scope
Returns:
Object of type :class:`cpl.dependency.scope.Scope`
"""
return Scope(self._service_provider)

View File

@@ -1,14 +1,12 @@
from inspect import isclass
from typing import Union, Type, Callable, Self
from cpl.core.log.logger import Logger
from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache
from cpl.dependency.hosted.startup_task import StartupTask
from cpl.dependency.module import Module
from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceCollection:
@@ -18,7 +16,7 @@ class ServiceCollection:
@classmethod
def with_module(cls, func: Callable, name: str = None) -> type[Self]:
# cls._modules[func.__name__ if name is None else name] = func
cls._modules[func.__name__ if name is None else name] = func
return cls
def __init__(self):
@@ -64,56 +62,23 @@ class ServiceCollection:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self
def add_startup_task(self, task: Type[StartupTask]) -> Self:
self.add_singleton(StartupTask, task)
return self
def add_hosted_service(self, service_type: T, service: Service = None) -> Self:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.hosted, service)
return self
def build(self) -> ServiceProvider:
def build(self) -> ServiceProviderABC:
sp = ServiceProvider(self._service_descriptors)
ServiceProviderABC.set_global_provider(sp)
return sp
def add_module(self, module: Type[Module]) -> Self:
assert isclass(module), "Module must be a Module"
assert issubclass(module, Module), f"Module must be subclass of {Module.__name__}"
def add_module(self, module: str | object) -> Self:
if not isinstance(module, str):
module = module.__name__
name = module.__name__
if module in self._modules:
if module not in self._modules:
raise ValueError(f"Module {module} not found")
for dependency in module.dependencies():
if dependency.__name__ not in self._loaded_modules:
self.add_module(dependency)
module().register(self)
if name not in self._loaded_modules:
self._loaded_modules.add(name)
self._modules[module](self)
if module not in self._loaded_modules:
self._loaded_modules.add(module)
return self
def add_logging(self) -> Self:
from cpl.core.log.logger import Logger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, Logger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_structured_logging(self) -> Self:
from cpl.core.log.structured_logger import StructuredLogger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, StructuredLogger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_cache(self, t: Type[T]):
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
return self

View File

@@ -1,6 +1,6 @@
from typing import Union, Optional
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
class ServiceDescriptor:

Some files were not shown because too many files have changed in this diff Show More