WIP: dev into master #184
@@ -1,9 +1,10 @@
|
|||||||
from cpl.dependency import ServiceProvider, ServiceProviderABC
|
from cpl.dependency import ServiceProvider, ServiceProviderABC
|
||||||
|
from cpl.dependency.inject import inject
|
||||||
from di.test_service import TestService
|
from di.test_service import TestService
|
||||||
|
|
||||||
|
|
||||||
class StaticTest:
|
class StaticTest:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ServiceProvider.inject
|
@inject
|
||||||
def test(services: ServiceProviderABC, t1: TestService):
|
def test(services: ServiceProviderABC, t1: TestService):
|
||||||
t1.run()
|
t1.run()
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from cpl.api.settings import ApiSettings
|
|||||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||||
from cpl.application.abc.application_abc import ApplicationABC
|
from cpl.application.abc.application_abc import ApplicationABC
|
||||||
from cpl.core.configuration import Configuration
|
from cpl.core.configuration import Configuration
|
||||||
|
from cpl.dependency.inject import inject
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
@@ -44,15 +45,15 @@ class WebApp(ApplicationABC):
|
|||||||
self._policies = services.get_service(PolicyRegistry)
|
self._policies = services.get_service(PolicyRegistry)
|
||||||
self._routes = services.get_service(RouteRegistry)
|
self._routes = services.get_service(RouteRegistry)
|
||||||
|
|
||||||
self._middleware: list[Middleware] = [
|
self._middleware: list[Middleware] = []
|
||||||
Middleware(RequestMiddleware),
|
|
||||||
Middleware(LoggingMiddleware),
|
|
||||||
]
|
|
||||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||||
Exception: self._handle_exception,
|
Exception: self._handle_exception,
|
||||||
APIError: self._handle_exception,
|
APIError: self._handle_exception,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.with_middleware(RequestMiddleware)
|
||||||
|
self.with_middleware(LoggingMiddleware)
|
||||||
|
|
||||||
async def _handle_exception(self, request: Request, exc: Exception):
|
async def _handle_exception(self, request: Request, exc: Exception):
|
||||||
if isinstance(exc, APIError):
|
if isinstance(exc, APIError):
|
||||||
self._logger.error(exc)
|
self._logger.error(exc)
|
||||||
@@ -168,9 +169,9 @@ class WebApp(ApplicationABC):
|
|||||||
self._check_for_app()
|
self._check_for_app()
|
||||||
|
|
||||||
if isinstance(middleware, Middleware):
|
if isinstance(middleware, Middleware):
|
||||||
self._middleware.append(middleware)
|
self._middleware.append(inject(middleware))
|
||||||
elif callable(middleware):
|
elif callable(middleware):
|
||||||
self._middleware.append(Middleware(middleware))
|
self._middleware.append(Middleware(inject(middleware)))
|
||||||
else:
|
else:
|
||||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||||
|
|
||||||
@@ -220,7 +221,7 @@ class WebApp(ApplicationABC):
|
|||||||
self._validate_policies()
|
self._validate_policies()
|
||||||
|
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
routes = [route.to_starlette(inject) for route in self._routes.all()]
|
||||||
|
|
||||||
app = Starlette(
|
app = Starlette(
|
||||||
routes=routes,
|
routes=routes,
|
||||||
|
|||||||
@@ -9,12 +9,10 @@ from cpl.api.router import Router
|
|||||||
from cpl.auth.keycloak import KeycloakClient
|
from cpl.auth.keycloak import KeycloakClient
|
||||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||||
from cpl.core.ctx import set_user
|
from cpl.core.ctx import set_user
|
||||||
from cpl.dependency import ServiceProviderABC
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware(ASGIMiddleware):
|
class AuthenticationMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
@ServiceProviderABC.inject
|
|
||||||
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||||
ASGIMiddleware.__init__(self, app)
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
|||||||
@@ -9,12 +9,10 @@ from cpl.api.registry.policy import PolicyRegistry
|
|||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||||
from cpl.core.ctx.user_context import get_user
|
from cpl.core.ctx.user_context import get_user
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
|
||||||
|
|
||||||
|
|
||||||
class AuthorizationMiddleware(ASGIMiddleware):
|
class AuthorizationMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
@ServiceProviderABC.inject
|
|
||||||
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||||
ASGIMiddleware.__init__(self, app)
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
|||||||
@@ -6,12 +6,10 @@ from starlette.types import Receive, Scope, Send
|
|||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.middleware.request import get_request
|
from cpl.api.middleware.request import get_request
|
||||||
from cpl.dependency import ServiceProviderABC
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingMiddleware(ASGIMiddleware):
|
class LoggingMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
@ServiceProviderABC.inject
|
|
||||||
def __init__(self, app, logger: APILogger):
|
def __init__(self, app, logger: APILogger):
|
||||||
ASGIMiddleware.__init__(self, app)
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
|||||||
@@ -9,14 +9,12 @@ from starlette.types import Scope, Receive, Send
|
|||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.typing import TRequest
|
from cpl.api.typing import TRequest
|
||||||
from cpl.dependency import ServiceProviderABC
|
|
||||||
|
|
||||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||||
|
|
||||||
|
|
||||||
class RequestMiddleware(ASGIMiddleware):
|
class RequestMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
@ServiceProviderABC.inject
|
|
||||||
def __init__(self, app, logger: APILogger):
|
def __init__(self, app, logger: APILogger):
|
||||||
ASGIMiddleware.__init__(self, app)
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC
|
|||||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||||
from cpl.application.host import Host
|
from cpl.application.host import Host
|
||||||
from cpl.core.errors import dependency_error
|
from cpl.core.errors import dependency_error
|
||||||
|
from cpl.dependency.context import get_current_provider, use_root_provider
|
||||||
from cpl.dependency.service_collection import ServiceCollection
|
from cpl.dependency.service_collection import ServiceCollection
|
||||||
|
|
||||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||||
@@ -21,6 +22,7 @@ class ApplicationBuilder(Generic[TApp]):
|
|||||||
self._app = app if app is not None else ApplicationABC
|
self._app = app if app is not None else ApplicationABC
|
||||||
|
|
||||||
self._services = ServiceCollection()
|
self._services = ServiceCollection()
|
||||||
|
use_root_provider(self._services.build())
|
||||||
|
|
||||||
self._startup: Optional[StartupABC] = None
|
self._startup: Optional[StartupABC] = None
|
||||||
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
|
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
|
||||||
@@ -34,7 +36,12 @@ class ApplicationBuilder(Generic[TApp]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def service_provider(self):
|
def service_provider(self):
|
||||||
return self._services.build()
|
provider = get_current_provider()
|
||||||
|
if provider is None:
|
||||||
|
provider = self._services.build()
|
||||||
|
use_root_provider(provider)
|
||||||
|
|
||||||
|
return provider
|
||||||
|
|
||||||
def validate_app_required_modules(self, app: ApplicationABC):
|
def validate_app_required_modules(self, app: ApplicationABC):
|
||||||
for module in app.required_modules:
|
for module in app.required_modules:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Type
|
|||||||
|
|
||||||
from cpl.core.log import LoggerABC, LogLevel
|
from cpl.core.log import LoggerABC, LogLevel
|
||||||
from cpl.core.typing import Messages
|
from cpl.core.typing import Messages
|
||||||
|
from cpl.dependency.inject import inject
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
@@ -17,7 +18,7 @@ class WrappedLogger(LoggerABC):
|
|||||||
|
|
||||||
self._set_logger()
|
self._set_logger()
|
||||||
|
|
||||||
@ServiceProviderABC.inject
|
@inject
|
||||||
def _set_logger(self, services: ServiceProviderABC):
|
def _set_logger(self, services: ServiceProviderABC):
|
||||||
from cpl.core.log import Logger
|
from cpl.core.log import Logger
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from .context import get_current_provider, use_provider
|
||||||
|
from .inject import inject
|
||||||
from .scope import Scope
|
from .scope import Scope
|
||||||
from .scope_abc import ScopeABC
|
from .scope_abc import ScopeABC
|
||||||
from .service_collection import ServiceCollection
|
from .service_collection import ServiceCollection
|
||||||
|
|||||||
21
src/cpl-dependency/cpl/dependency/context.py
Normal file
21
src/cpl-dependency/cpl/dependency/context.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import contextvars
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
_current_provider = contextvars.ContextVar("current_provider", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def use_root_provider(provider):
|
||||||
|
_current_provider.set(provider)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def use_provider(provider):
|
||||||
|
token = _current_provider.set(provider)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_current_provider.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_provider():
|
||||||
|
return _current_provider.get()
|
||||||
42
src/cpl-dependency/cpl/dependency/inject.py
Normal file
42
src/cpl-dependency/cpl/dependency/inject.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import functools
|
||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
from cpl.dependency.context import get_current_provider
|
||||||
|
|
||||||
|
|
||||||
|
def inject(f=None):
|
||||||
|
if f is None:
|
||||||
|
return functools.partial(inject)
|
||||||
|
|
||||||
|
if iscoroutinefunction(f):
|
||||||
|
|
||||||
|
@functools.wraps(f)
|
||||||
|
async def async_inner(*args, **kwargs):
|
||||||
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
|
||||||
|
provider: ServiceProvider | None = get_current_provider()
|
||||||
|
if provider is None:
|
||||||
|
raise ValueError(
|
||||||
|
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
|
||||||
|
)
|
||||||
|
|
||||||
|
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
|
||||||
|
return await f(*args, *injection, **kwargs)
|
||||||
|
|
||||||
|
return async_inner
|
||||||
|
|
||||||
|
@functools.wraps(f)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
|
||||||
|
provider: ServiceProvider | None = get_current_provider()
|
||||||
|
if provider is None:
|
||||||
|
raise ValueError(
|
||||||
|
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
|
||||||
|
)
|
||||||
|
|
||||||
|
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
|
||||||
|
return f(*args, *injection, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
import functools
|
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from inspect import Signature, signature, iscoroutinefunction
|
from inspect import Signature
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from cpl.core.typing import T, R
|
from cpl.core.typing import T
|
||||||
from cpl.dependency.scope_abc import ScopeABC
|
from cpl.dependency.scope_abc import ScopeABC
|
||||||
|
|
||||||
|
|
||||||
@@ -126,40 +125,3 @@ class ServiceProviderABC(ABC):
|
|||||||
-------
|
-------
|
||||||
Object of type list[:class:`type`]
|
Object of type list[:class:`type`]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def inject(cls, f=None):
|
|
||||||
r"""Decorator to allow injection into static and class methods
|
|
||||||
|
|
||||||
Parameter
|
|
||||||
---------
|
|
||||||
f: Callable
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
function
|
|
||||||
"""
|
|
||||||
if f is None:
|
|
||||||
return functools.partial(cls.inject)
|
|
||||||
|
|
||||||
if iscoroutinefunction(f):
|
|
||||||
|
|
||||||
@functools.wraps(f)
|
|
||||||
async def async_inner(*args, **kwargs):
|
|
||||||
if cls._provider is None:
|
|
||||||
raise Exception(f"{cls.__name__} not build!")
|
|
||||||
|
|
||||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
|
||||||
return await f(*args, *injection, **kwargs)
|
|
||||||
|
|
||||||
return async_inner
|
|
||||||
|
|
||||||
@functools.wraps(f)
|
|
||||||
def inner(*args, **kwargs):
|
|
||||||
if cls._provider is None:
|
|
||||||
raise Exception(f"{cls.__name__} not build!")
|
|
||||||
|
|
||||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
|
||||||
return f(*args, *injection, **kwargs)
|
|
||||||
|
|
||||||
return inner
|
|
||||||
|
|||||||
Reference in New Issue
Block a user