DI Provider ctx #186

This commit is contained in:
2025-09-24 20:46:43 +02:00
parent cf8edafd39
commit cdb4a0fb34
12 changed files with 87 additions and 58 deletions

View File

@@ -1,9 +1,10 @@
from cpl.dependency import ServiceProvider, ServiceProviderABC from cpl.dependency import ServiceProvider, ServiceProviderABC
from cpl.dependency.inject import inject
from di.test_service import TestService from di.test_service import TestService
class StaticTest: class StaticTest:
@staticmethod @staticmethod
@ServiceProvider.inject @inject
def test(services: ServiceProviderABC, t1: TestService): def test(services: ServiceProviderABC, t1: TestService):
t1.run() t1.run()

View File

@@ -27,6 +27,7 @@ from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.dependency.inject import inject
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider_abc import ServiceProviderABC
@@ -44,15 +45,15 @@ class WebApp(ApplicationABC):
self._policies = services.get_service(PolicyRegistry) self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry) self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = [ self._middleware: list[Middleware] = []
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = { self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception, Exception: self._handle_exception,
APIError: self._handle_exception, APIError: self._handle_exception,
} }
self.with_middleware(RequestMiddleware)
self.with_middleware(LoggingMiddleware)
async def _handle_exception(self, request: Request, exc: Exception): async def _handle_exception(self, request: Request, exc: Exception):
if isinstance(exc, APIError): if isinstance(exc, APIError):
self._logger.error(exc) self._logger.error(exc)
@@ -168,9 +169,9 @@ class WebApp(ApplicationABC):
self._check_for_app() self._check_for_app()
if isinstance(middleware, Middleware): if isinstance(middleware, Middleware):
self._middleware.append(middleware) self._middleware.append(inject(middleware))
elif callable(middleware): elif callable(middleware):
self._middleware.append(Middleware(middleware)) self._middleware.append(Middleware(inject(middleware)))
else: else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable") raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
@@ -220,7 +221,7 @@ class WebApp(ApplicationABC):
self._validate_policies() self._validate_policies()
if self._app is None: if self._app is None:
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()] routes = [route.to_starlette(inject) for route in self._routes.all()]
app = Starlette( app = Starlette(
routes=routes, routes=routes,

View File

@@ -9,12 +9,10 @@ from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
class AuthenticationMiddleware(ASGIMiddleware): class AuthenticationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)

View File

@@ -9,12 +9,10 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
class AuthorizationMiddleware(ASGIMiddleware): class AuthorizationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)

View File

@@ -6,12 +6,10 @@ from starlette.types import Receive, Scope, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.dependency import ServiceProviderABC
class LoggingMiddleware(ASGIMiddleware): class LoggingMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger): def __init__(self, app, logger: APILogger):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)

View File

@@ -9,14 +9,12 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
from cpl.dependency import ServiceProviderABC
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
class RequestMiddleware(ASGIMiddleware): class RequestMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger): def __init__(self, app, logger: APILogger):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)

View File

@@ -7,6 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.errors import dependency_error from cpl.core.errors import dependency_error
from cpl.dependency.context import get_current_provider, use_root_provider
from cpl.dependency.service_collection import ServiceCollection from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC) TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -21,6 +22,7 @@ class ApplicationBuilder(Generic[TApp]):
self._app = app if app is not None else ApplicationABC self._app = app if app is not None else ApplicationABC
self._services = ServiceCollection() self._services = ServiceCollection()
use_root_provider(self._services.build())
self._startup: Optional[StartupABC] = None self._startup: Optional[StartupABC] = None
self._app_extensions: list[Type[ApplicationExtensionABC]] = [] self._app_extensions: list[Type[ApplicationExtensionABC]] = []
@@ -34,7 +36,12 @@ class ApplicationBuilder(Generic[TApp]):
@property @property
def service_provider(self): def service_provider(self):
return self._services.build() provider = get_current_provider()
if provider is None:
provider = self._services.build()
use_root_provider(provider)
return provider
def validate_app_required_modules(self, app: ApplicationABC): def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules: for module in app.required_modules:

View File

@@ -3,6 +3,7 @@ from typing import Type
from cpl.core.log import LoggerABC, LogLevel from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages from cpl.core.typing import Messages
from cpl.dependency.inject import inject
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider_abc import ServiceProviderABC
@@ -17,7 +18,7 @@ class WrappedLogger(LoggerABC):
self._set_logger() self._set_logger()
@ServiceProviderABC.inject @inject
def _set_logger(self, services: ServiceProviderABC): def _set_logger(self, services: ServiceProviderABC):
from cpl.core.log import Logger from cpl.core.log import Logger

View File

@@ -1,3 +1,5 @@
from .context import get_current_provider, use_provider
from .inject import inject
from .scope import Scope from .scope import Scope
from .scope_abc import ScopeABC from .scope_abc import ScopeABC
from .service_collection import ServiceCollection from .service_collection import ServiceCollection

View File

@@ -0,0 +1,21 @@
import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider):
_current_provider.set(provider)
@contextmanager
def use_provider(provider):
token = _current_provider.set(provider)
try:
yield
finally:
_current_provider.reset(token)
def get_current_provider():
return _current_provider.get()

View File

@@ -0,0 +1,42 @@
import functools
from asyncio import iscoroutinefunction
from inspect import signature
from cpl.dependency.context import get_current_provider
def inject(f=None):
if f is None:
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_current_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_current_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -1,9 +1,8 @@
import functools
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from inspect import Signature, signature, iscoroutinefunction from inspect import Signature
from typing import Optional, Type from typing import Optional, Type
from cpl.core.typing import T, R from cpl.core.typing import T
from cpl.dependency.scope_abc import ScopeABC from cpl.dependency.scope_abc import ScopeABC
@@ -126,40 +125,3 @@ class ServiceProviderABC(ABC):
------- -------
Object of type list[:class:`type`] Object of type list[:class:`type`]
""" """
@classmethod
def inject(cls, f=None):
r"""Decorator to allow injection into static and class methods
Parameter
---------
f: Callable
Returns
-------
function
"""
if f is None:
return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner