WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
9 changed files with 67 additions and 48 deletions
Showing only changes of commit b49f663ae0 - Show all commits

View File

@@ -6,8 +6,10 @@ from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from custom.api.src.scoped_service import ScopedService
from service import PingService
@@ -23,6 +25,8 @@ def main():
builder.services.add_transient(PingService)
builder.services.add_module(api)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
@@ -40,6 +44,32 @@ def main():
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache:
raise Exception("Cache service is not working")
s1 = provider.get_service(ScopedService)
s2 = provider.get_service(ScopedService)
if s1.name == s2.name:
raise Exception("Scoped service is not working")
with provider.create_scope() as scope:
s3 = scope.get_service(ScopedService)
s4 = scope.get_service(ScopedService)
if s3.name != s4.name:
raise Exception("Scoped service is not working")
if s1.name == s3.name:
raise Exception("Scoped service is not working")
Console.write_line(
s1.name,
s2.name,
s3.name,
s4.name,
)
app.run()

View File

@@ -5,12 +5,17 @@ from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
from cpl.core.console import Console
from cpl.dependency import ServiceProvider
from custom.api.src.scoped_service import ScopedService
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger):
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
logger.info(f"Ping: {ping}")
Console.write_line(scoped.name)
return JSONResponse(ping.ping(r))

View File

@@ -0,0 +1,14 @@
from cpl.core.console.console import Console
from cpl.core.utils.string import String
class ScopedService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -62,7 +62,7 @@ class Application(ApplicationABC):
root_scoped_service2 = self._services.get_service(ScopedService)
Console.write_line(root_scoped_service2)
if root_scoped_service != root_scoped_service2:
if root_scoped_service == root_scoped_service2:
raise Exception("Root scoped service should be equal to root scoped service 2")
test_settings = Configuration.get(TestSettings)

View File

@@ -30,7 +30,6 @@ from cpl.core.configuration import Configuration
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
PolicyInput = Union[dict[str, PolicyResolver], Policy]

View File

@@ -1,13 +0,0 @@
from cpl.api.abc import ASGIMiddleware
from cpl.dependency.service_provider import ServiceProvider
class ScopeMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider):
ASGIMiddleware.__init__(self, app)
self._app = app
self._provider = provider
async def __call__(self, scope, receive, send):
with self._provider.create_scope():
await self._app(scope, receive, send)

View File

@@ -9,15 +9,18 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger):
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._ctx_token = None
@@ -27,7 +30,8 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request)
try:
await self._app(scope, receive, send)
with self._provider.create_scope():
inject(await self._app(scope, receive, send))
finally:
await self.clean_request_data()

View File

@@ -10,7 +10,6 @@ def inject(f=None):
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider

View File

@@ -14,23 +14,9 @@ from cpl.dependency.service_lifetime import ServiceLifetimeEnum
class ServiceProvider:
r"""Provider for the services
Parameter
---------
service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`]
Descriptor of the service
config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC`
CPL Configuration
db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`]
Database representation
"""
def __init__(
self,
service_descriptors: list[ServiceDescriptor],
):
def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False):
self._service_descriptors: list[ServiceDescriptor] = service_descriptors
self._is_scope = is_scope
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type
@@ -63,7 +49,7 @@ class ServiceProvider:
return descriptor.implementation
implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
descriptor.implementation = implementation
return implementation
@@ -81,7 +67,7 @@ class ServiceProvider:
implementation = self._build_service(
descriptor.service_type, origin_service_type=service_type, **kwargs
)
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
descriptor.implementation = implementation
implementations.append(implementation)
@@ -127,12 +113,10 @@ class ServiceProvider:
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
break
sig = signature(service_type.__init__)
params = self._build_by_signature(sig, origin_service_type)
return service_type(*params, *args, **kwargs)
@contextmanager
@@ -144,13 +128,12 @@ class ServiceProvider:
else:
scoped_descriptors.append(copy.deepcopy(d))
scoped_provider = ServiceProvider(scoped_descriptors)
scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True)
with use_provider(scoped_provider):
yield scoped_provider
def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
result = self._find_service(service_type)
if result is None:
return None
@@ -158,9 +141,10 @@ class ServiceProvider:
return result.implementation
implementation = self._build_service(service_type, *args, **kwargs)
if (
result.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped)
):
if result.lifetime == ServiceLifetimeEnum.singleton:
result.implementation = implementation
elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope:
result.implementation = implementation
return implementation
@@ -173,12 +157,9 @@ class ServiceProvider:
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
implementations = []
if typing.get_origin(service_type) == list:
raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
implementations.extend(self._get_services(service_type))
return implementations
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]: