diff --git a/example/custom/api/src/main.py b/example/custom/api/src/main.py index 3a5ee2f9..5f5d2428 100644 --- a/example/custom/api/src/main.py +++ b/example/custom/api/src/main.py @@ -6,8 +6,10 @@ from cpl.application import ApplicationBuilder from cpl.auth.permission.permissions import Permissions from cpl.auth.schema import AuthUser, Role from cpl.core.configuration import Configuration +from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.utils.cache import Cache +from custom.api.src.scoped_service import ScopedService from service import PingService @@ -23,6 +25,8 @@ def main(): builder.services.add_transient(PingService) builder.services.add_module(api) + builder.services.add_scoped(ScopedService) + builder.services.add_cache(AuthUser) builder.services.add_cache(Role) @@ -40,6 +44,32 @@ def main(): user_cache = provider.get_service(Cache[AuthUser]) role_cache = provider.get_service(Cache[Role]) + if role_cache == user_cache: + raise Exception("Cache service is not working") + + s1 = provider.get_service(ScopedService) + s2 = provider.get_service(ScopedService) + + if s1.name == s2.name: + raise Exception("Scoped service is not working") + + with provider.create_scope() as scope: + s3 = scope.get_service(ScopedService) + s4 = scope.get_service(ScopedService) + + if s3.name != s4.name: + raise Exception("Scoped service is not working") + + if s1.name == s3.name: + raise Exception("Scoped service is not working") + + Console.write_line( + s1.name, + s2.name, + s3.name, + s4.name, + ) + app.run() diff --git a/example/custom/api/src/routes/ping.py b/example/custom/api/src/routes/ping.py index 7fad7145..7bb40ad2 100644 --- a/example/custom/api/src/routes/ping.py +++ b/example/custom/api/src/routes/ping.py @@ -5,12 +5,17 @@ from starlette.responses import JSONResponse from cpl.api import APILogger from cpl.api.router import Router +from cpl.core.console import Console +from cpl.dependency import ServiceProvider +from custom.api.src.scoped_service import ScopedService @Router.authenticate() # @Router.authorize(permissions=[Permissions.administrator]) # @Router.authorize(policies=["test"]) @Router.get(f"/ping") -async def ping(r: Request, ping: PingService, logger: APILogger): +async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService): logger.info(f"Ping: {ping}") + + Console.write_line(scoped.name) return JSONResponse(ping.ping(r)) diff --git a/example/custom/api/src/scoped_service.py b/example/custom/api/src/scoped_service.py new file mode 100644 index 00000000..f8c9b15a --- /dev/null +++ b/example/custom/api/src/scoped_service.py @@ -0,0 +1,14 @@ +from cpl.core.console.console import Console +from cpl.core.utils.string import String + + +class ScopedService: + def __init__(self): + self._name = String.random(8) + + @property + def name(self) -> str: + return self._name + + def run(self): + Console.write_line(f"Im {self._name}") diff --git a/example/custom/general/src/general/application.py b/example/custom/general/src/general/application.py index b22ce377..5486060b 100644 --- a/example/custom/general/src/general/application.py +++ b/example/custom/general/src/general/application.py @@ -62,7 +62,7 @@ class Application(ApplicationABC): root_scoped_service2 = self._services.get_service(ScopedService) Console.write_line(root_scoped_service2) - if root_scoped_service != root_scoped_service2: + if root_scoped_service == root_scoped_service2: raise Exception("Root scoped service should be equal to root scoped service 2") test_settings = Configuration.get(TestSettings) diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index c6a28b48..cc9d8dce 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -30,7 +30,6 @@ from cpl.core.configuration import Configuration from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider - PolicyInput = Union[dict[str, PolicyResolver], Policy] diff --git a/src/cpl-api/cpl/api/middleware/_scope_middleware.py b/src/cpl-api/cpl/api/middleware/_scope_middleware.py deleted file mode 100644 index cc49631f..00000000 --- a/src/cpl-api/cpl/api/middleware/_scope_middleware.py +++ /dev/null @@ -1,13 +0,0 @@ -from cpl.api.abc import ASGIMiddleware -from cpl.dependency.service_provider import ServiceProvider - - -class ScopeMiddleware(ASGIMiddleware): - def __init__(self, app, provider: ServiceProvider): - ASGIMiddleware.__init__(self, app) - self._app = app - self._provider = provider - - async def __call__(self, scope, receive, send): - with self._provider.create_scope(): - await self._app(scope, receive, send) \ No newline at end of file diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index e0b88b89..0cedc88b 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -9,15 +9,18 @@ from starlette.types import Scope, Receive, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.typing import TRequest +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) class RequestMiddleware(ASGIMiddleware): - def __init__(self, app, logger: APILogger): + def __init__(self, app, provider: ServiceProvider, logger: APILogger): ASGIMiddleware.__init__(self, app) + self._provider = provider self._logger = logger self._ctx_token = None @@ -27,7 +30,8 @@ class RequestMiddleware(ASGIMiddleware): await self.set_request_data(request) try: - await self._app(scope, receive, send) + with self._provider.create_scope(): + inject(await self._app(scope, receive, send)) finally: await self.clean_request_data() diff --git a/src/cpl-dependency/cpl/dependency/inject.py b/src/cpl-dependency/cpl/dependency/inject.py index 3e6b915f..f49579af 100644 --- a/src/cpl-dependency/cpl/dependency/inject.py +++ b/src/cpl-dependency/cpl/dependency/inject.py @@ -10,7 +10,6 @@ def inject(f=None): return functools.partial(inject) if iscoroutinefunction(f): - @functools.wraps(f) async def async_inner(*args, **kwargs): from cpl.dependency.service_provider import ServiceProvider diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index 09fc3852..0be72c42 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -14,23 +14,9 @@ from cpl.dependency.service_lifetime import ServiceLifetimeEnum class ServiceProvider: - r"""Provider for the services - - Parameter - --------- - service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`] - Descriptor of the service - config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC` - CPL Configuration - db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`] - Database representation - """ - - def __init__( - self, - service_descriptors: list[ServiceDescriptor], - ): + def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False): self._service_descriptors: list[ServiceDescriptor] = service_descriptors + self._is_scope = is_scope def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: origin_type = typing.get_origin(service_type) or service_type @@ -57,13 +43,13 @@ class ServiceProvider: def _get_service(self, parameter: Parameter, origin_service_type: type = None) -> Optional[object]: for descriptor in self._service_descriptors: if descriptor.service_type == parameter.annotation or issubclass( - descriptor.service_type, parameter.annotation + descriptor.service_type, parameter.annotation ): if descriptor.implementation is not None: return descriptor.implementation implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type) - if descriptor.lifetime == ServiceLifetimeEnum.singleton: + if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped): descriptor.implementation = implementation return implementation @@ -81,7 +67,7 @@ class ServiceProvider: implementation = self._build_service( descriptor.service_type, origin_service_type=service_type, **kwargs ) - if descriptor.lifetime == ServiceLifetimeEnum.singleton: + if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped): descriptor.implementation = implementation implementations.append(implementation) @@ -127,12 +113,10 @@ class ServiceProvider: service_type = type(descriptor.implementation) else: service_type = descriptor.service_type - break sig = signature(service_type.__init__) params = self._build_by_signature(sig, origin_service_type) - return service_type(*params, *args, **kwargs) @contextmanager @@ -144,13 +128,12 @@ class ServiceProvider: else: scoped_descriptors.append(copy.deepcopy(d)) - scoped_provider = ServiceProvider(scoped_descriptors) + scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True) with use_provider(scoped_provider): yield scoped_provider def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]: result = self._find_service(service_type) - if result is None: return None @@ -158,9 +141,10 @@ class ServiceProvider: return result.implementation implementation = self._build_service(service_type, *args, **kwargs) - if ( - result.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped) - ): + + if result.lifetime == ServiceLifetimeEnum.singleton: + result.implementation = implementation + elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope: result.implementation = implementation return implementation @@ -173,12 +157,9 @@ class ServiceProvider: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]: implementations = [] - if typing.get_origin(service_type) == list: raise Exception(f"Invalid type {service_type}! Expected single type not list of type") - implementations.extend(self._get_services(service_type)) - return implementations def get_service_types(self, service_type: Type[T]) -> list[Type[T]]: