diff --git a/example/custom/di/src/di/static_test.py b/example/custom/di/src/di/static_test.py index d3c10f60..edce2f40 100644 --- a/example/custom/di/src/di/static_test.py +++ b/example/custom/di/src/di/static_test.py @@ -1,9 +1,10 @@ from cpl.dependency import ServiceProvider, ServiceProviderABC +from cpl.dependency.inject import inject from di.test_service import TestService class StaticTest: @staticmethod - @ServiceProvider.inject + @inject def test(services: ServiceProviderABC, t1: TestService): t1.run() diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index 63631054..dbff7557 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -27,6 +27,7 @@ from cpl.api.settings import ApiSettings from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.application.abc.application_abc import ApplicationABC from cpl.core.configuration import Configuration +from cpl.dependency.inject import inject from cpl.dependency.service_provider_abc import ServiceProviderABC @@ -44,15 +45,15 @@ class WebApp(ApplicationABC): self._policies = services.get_service(PolicyRegistry) self._routes = services.get_service(RouteRegistry) - self._middleware: list[Middleware] = [ - Middleware(RequestMiddleware), - Middleware(LoggingMiddleware), - ] + self._middleware: list[Middleware] = [] self._exception_handlers: Mapping[Any, ExceptionHandler] = { Exception: self._handle_exception, APIError: self._handle_exception, } + self.with_middleware(RequestMiddleware) + self.with_middleware(LoggingMiddleware) + async def _handle_exception(self, request: Request, exc: Exception): if isinstance(exc, APIError): self._logger.error(exc) @@ -168,9 +169,9 @@ class WebApp(ApplicationABC): self._check_for_app() if isinstance(middleware, Middleware): - self._middleware.append(middleware) + self._middleware.append(inject(middleware)) elif callable(middleware): - self._middleware.append(Middleware(middleware)) + self._middleware.append(Middleware(inject(middleware))) else: raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable") @@ -220,7 +221,7 @@ class WebApp(ApplicationABC): self._validate_policies() if self._app is None: - routes = [route.to_starlette(self._services.inject) for route in self._routes.all()] + routes = [route.to_starlette(inject) for route in self._routes.all()] app = Starlette( routes=routes, diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index dd6b0bd6..c0dc95f1 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -9,12 +9,10 @@ from cpl.api.router import Router from cpl.auth.keycloak import KeycloakClient from cpl.auth.schema import AuthUserDao, AuthUser from cpl.core.ctx import set_user -from cpl.dependency import ServiceProviderABC class AuthenticationMiddleware(ASGIMiddleware): - @ServiceProviderABC.inject def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): ASGIMiddleware.__init__(self, app) diff --git a/src/cpl-api/cpl/api/middleware/authorization.py b/src/cpl-api/cpl/api/middleware/authorization.py index 4125e65d..b0b0d18c 100644 --- a/src/cpl-api/cpl/api/middleware/authorization.py +++ b/src/cpl-api/cpl/api/middleware/authorization.py @@ -9,12 +9,10 @@ from cpl.api.registry.policy import PolicyRegistry from cpl.api.router import Router from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.core.ctx.user_context import get_user -from cpl.dependency.service_provider_abc import ServiceProviderABC class AuthorizationMiddleware(ASGIMiddleware): - @ServiceProviderABC.inject def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): ASGIMiddleware.__init__(self, app) diff --git a/src/cpl-api/cpl/api/middleware/logging.py b/src/cpl-api/cpl/api/middleware/logging.py index 6a28b22c..53655757 100644 --- a/src/cpl-api/cpl/api/middleware/logging.py +++ b/src/cpl-api/cpl/api/middleware/logging.py @@ -6,12 +6,10 @@ from starlette.types import Receive, Scope, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.middleware.request import get_request -from cpl.dependency import ServiceProviderABC class LoggingMiddleware(ASGIMiddleware): - @ServiceProviderABC.inject def __init__(self, app, logger: APILogger): ASGIMiddleware.__init__(self, app) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 215ff683..e0b88b89 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -9,14 +9,12 @@ from starlette.types import Scope, Receive, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.typing import TRequest -from cpl.dependency import ServiceProviderABC _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) class RequestMiddleware(ASGIMiddleware): - @ServiceProviderABC.inject def __init__(self, app, logger: APILogger): ASGIMiddleware.__init__(self, app) diff --git a/src/cpl-application/cpl/application/application_builder.py b/src/cpl-application/cpl/application/application_builder.py index a1303aa1..cc36da88 100644 --- a/src/cpl-application/cpl/application/application_builder.py +++ b/src/cpl-application/cpl/application/application_builder.py @@ -7,6 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.host import Host from cpl.core.errors import dependency_error +from cpl.dependency.context import get_current_provider, use_root_provider from cpl.dependency.service_collection import ServiceCollection TApp = TypeVar("TApp", bound=ApplicationABC) @@ -21,6 +22,7 @@ class ApplicationBuilder(Generic[TApp]): self._app = app if app is not None else ApplicationABC self._services = ServiceCollection() + use_root_provider(self._services.build()) self._startup: Optional[StartupABC] = None self._app_extensions: list[Type[ApplicationExtensionABC]] = [] @@ -34,7 +36,12 @@ class ApplicationBuilder(Generic[TApp]): @property def service_provider(self): - return self._services.build() + provider = get_current_provider() + if provider is None: + provider = self._services.build() + use_root_provider(provider) + + return provider def validate_app_required_modules(self, app: ApplicationABC): for module in app.required_modules: diff --git a/src/cpl-core/cpl/core/log/wrapped_logger.py b/src/cpl-core/cpl/core/log/wrapped_logger.py index 38411ed5..b56b4df8 100644 --- a/src/cpl-core/cpl/core/log/wrapped_logger.py +++ b/src/cpl-core/cpl/core/log/wrapped_logger.py @@ -3,6 +3,7 @@ from typing import Type from cpl.core.log import LoggerABC, LogLevel from cpl.core.typing import Messages +from cpl.dependency.inject import inject from cpl.dependency.service_provider_abc import ServiceProviderABC @@ -17,7 +18,7 @@ class WrappedLogger(LoggerABC): self._set_logger() - @ServiceProviderABC.inject + @inject def _set_logger(self, services: ServiceProviderABC): from cpl.core.log import Logger diff --git a/src/cpl-dependency/cpl/dependency/__init__.py b/src/cpl-dependency/cpl/dependency/__init__.py index 8aa7165e..fb8e0c72 100644 --- a/src/cpl-dependency/cpl/dependency/__init__.py +++ b/src/cpl-dependency/cpl/dependency/__init__.py @@ -1,3 +1,5 @@ +from .context import get_current_provider, use_provider +from .inject import inject from .scope import Scope from .scope_abc import ScopeABC from .service_collection import ServiceCollection diff --git a/src/cpl-dependency/cpl/dependency/context.py b/src/cpl-dependency/cpl/dependency/context.py new file mode 100644 index 00000000..f07aab4b --- /dev/null +++ b/src/cpl-dependency/cpl/dependency/context.py @@ -0,0 +1,21 @@ +import contextvars +from contextlib import contextmanager + +_current_provider = contextvars.ContextVar("current_provider", default=None) + + +def use_root_provider(provider): + _current_provider.set(provider) + + +@contextmanager +def use_provider(provider): + token = _current_provider.set(provider) + try: + yield + finally: + _current_provider.reset(token) + + +def get_current_provider(): + return _current_provider.get() diff --git a/src/cpl-dependency/cpl/dependency/inject.py b/src/cpl-dependency/cpl/dependency/inject.py new file mode 100644 index 00000000..6b92368c --- /dev/null +++ b/src/cpl-dependency/cpl/dependency/inject.py @@ -0,0 +1,42 @@ +import functools +from asyncio import iscoroutinefunction +from inspect import signature + +from cpl.dependency.context import get_current_provider + + +def inject(f=None): + if f is None: + return functools.partial(inject) + + if iscoroutinefunction(f): + + @functools.wraps(f) + async def async_inner(*args, **kwargs): + from cpl.dependency.service_provider import ServiceProvider + + provider: ServiceProvider | None = get_current_provider() + if provider is None: + raise ValueError( + "No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context." + ) + + injection = [x for x in provider._build_by_signature(signature(f)) if x is not None] + return await f(*args, *injection, **kwargs) + + return async_inner + + @functools.wraps(f) + def inner(*args, **kwargs): + from cpl.dependency.service_provider import ServiceProvider + + provider: ServiceProvider | None = get_current_provider() + if provider is None: + raise ValueError( + "No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context." + ) + + injection = [x for x in provider._build_by_signature(signature(f)) if x is not None] + return f(*args, *injection, **kwargs) + + return inner diff --git a/src/cpl-dependency/cpl/dependency/service_provider_abc.py b/src/cpl-dependency/cpl/dependency/service_provider_abc.py index 53e87ba2..7b0d39dd 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider_abc.py +++ b/src/cpl-dependency/cpl/dependency/service_provider_abc.py @@ -1,9 +1,8 @@ -import functools from abc import abstractmethod, ABC -from inspect import Signature, signature, iscoroutinefunction +from inspect import Signature from typing import Optional, Type -from cpl.core.typing import T, R +from cpl.core.typing import T from cpl.dependency.scope_abc import ScopeABC @@ -126,40 +125,3 @@ class ServiceProviderABC(ABC): ------- Object of type list[:class:`type`] """ - - @classmethod - def inject(cls, f=None): - r"""Decorator to allow injection into static and class methods - - Parameter - --------- - f: Callable - - Returns - ------- - function - """ - if f is None: - return functools.partial(cls.inject) - - if iscoroutinefunction(f): - - @functools.wraps(f) - async def async_inner(*args, **kwargs): - if cls._provider is None: - raise Exception(f"{cls.__name__} not build!") - - injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None] - return await f(*args, *injection, **kwargs) - - return async_inner - - @functools.wraps(f) - def inner(*args, **kwargs): - if cls._provider is None: - raise Exception(f"{cls.__name__} not build!") - - injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None] - return f(*args, *injection, **kwargs) - - return inner