Compare commits

..

5 Commits

Author SHA1 Message Date
6a3fdb3ebd Fixed formatting #186
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 7s
Build on push / prepare (push) Successful in 9s
Build on push / core (push) Successful in 17s
Build on push / query (push) Successful in 17s
Build on push / dependency (push) Successful in 17s
Build on push / application (push) Successful in 16s
Build on push / database (push) Successful in 17s
Build on push / mail (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s
2025-09-24 21:48:57 +02:00
b49f663ae0 API scoped requests #186
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 18s
Build on push / core (push) Successful in 21s
Build on push / dependency (push) Successful in 14s
Build on push / api (push) Has been cancelled
Build on push / auth (push) Has been cancelled
Build on push / application (push) Has been cancelled
Build on push / database (push) Has been cancelled
Build on push / mail (push) Has been cancelled
Build on push / translation (push) Has been cancelled
2025-09-24 21:47:52 +02:00
287f5e3149 New implementation of scopes #186 2025-09-24 21:27:28 +02:00
4c8cd988cc Removed ServiceProviderABC #186 2025-09-24 20:53:01 +02:00
cdb4a0fb34 DI Provider ctx #186 2025-09-24 20:46:43 +02:00
58 changed files with 323 additions and 425 deletions

View File

@@ -6,8 +6,10 @@ from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from custom.api.src.scoped_service import ScopedService
from service import PingService
@@ -23,6 +25,8 @@ def main():
builder.services.add_transient(PingService)
builder.services.add_module(api)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
@@ -40,6 +44,32 @@ def main():
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache:
raise Exception("Cache service is not working")
s1 = provider.get_service(ScopedService)
s2 = provider.get_service(ScopedService)
if s1.name == s2.name:
raise Exception("Scoped service is not working")
with provider.create_scope() as scope:
s3 = scope.get_service(ScopedService)
s4 = scope.get_service(ScopedService)
if s3.name != s4.name:
raise Exception("Scoped service is not working")
if s1.name == s3.name:
raise Exception("Scoped service is not working")
Console.write_line(
s1.name,
s2.name,
s3.name,
s4.name,
)
app.run()

View File

@@ -5,12 +5,17 @@ from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
from cpl.core.console import Console
from cpl.dependency import ServiceProvider
from custom.api.src.scoped_service import ScopedService
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger):
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
logger.info(f"Ping: {ping}")
Console.write_line(scoped.name)
return JSONResponse(ping.ping(r))

View File

@@ -0,0 +1,14 @@
from cpl.core.console.console import Console
from cpl.core.utils.string import String
class ScopedService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -3,7 +3,7 @@ from cpl.auth.keycloak import KeycloakAdmin
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.log import LoggerABC
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
from model.city import City
from model.city_dao import CityDao
from model.user import User
@@ -11,7 +11,7 @@ from model.user_dao import UserDao
class Application(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services)
self._logger = services.get_service(LoggerABC)

View File

@@ -1,7 +1,6 @@
from cpl.application.abc import ApplicationABC
from cpl.core.console.console import Console
from cpl.dependency import ServiceProviderABC
from cpl.dependency.scope import Scope
from cpl.dependency import ServiceProvider
from di.static_test import StaticTest
from di.test_abc import TestABC
from di.test_service import TestService
@@ -10,33 +9,37 @@ from di.tester import Tester
class Application(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services)
def _part_of_scoped(self):
ts: TestService = self._services.get_service(TestService)
ts.run()
def configure(self): ...
def main(self):
with self._services.create_scope() as scope:
Console.write_line("Scope1")
ts: TestService = scope.service_provider.get_service(TestService)
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.service_provider.get_service(DITesterService)
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
with self._services.create_scope() as scope:
Console.write_line("Scope2")
ts: TestService = scope.service_provider.get_service(TestService)
ts: TestService = scope.get_service(TestService)
ts.run()
dit: DITesterService = scope.service_provider.get_service(DITesterService)
dit: DITesterService = scope.get_service(DITesterService)
dit.run()
if ts.name != dit.name:
raise Exception("DI is broken!")
Console.write_line("Global")
self._part_of_scoped()
StaticTest.test()
self._services.get_service(Tester)
Console.write_line(self._services.get_services(list[TestABC]))
Console.write_line(self._services.get_services(TestABC))

View File

@@ -6,6 +6,10 @@ class DITesterService:
def __init__(self, ts: TestService):
self._ts = ts
@property
def name(self) -> str:
return self._ts.name
def run(self):
Console.write_line("DIT: ")
self._ts.run()

View File

@@ -1,5 +1,5 @@
from cpl.application.abc import StartupABC
from cpl.dependency import ServiceProviderABC, ServiceCollection
from cpl.dependency import ServiceProvider, ServiceCollection
from di.di_tester_service import DITesterService
from di.test1_service import Test1Service
from di.test2_service import Test2Service
@@ -12,9 +12,11 @@ class Startup(StartupABC):
def __init__(self):
StartupABC.__init__(self)
def configure_configuration(self): ...
@staticmethod
def configure_configuration(): ...
def configure_services(self, services: ServiceCollection) -> ServiceProviderABC:
@staticmethod
def configure_services(services: ServiceCollection) -> ServiceProvider:
services.add_scoped(TestService)
services.add_scoped(DITesterService)

View File

@@ -1,9 +1,10 @@
from cpl.dependency import ServiceProvider, ServiceProviderABC
from cpl.dependency import ServiceProvider, ServiceProvider
from cpl.dependency.inject import inject
from di.test_service import TestService
class StaticTest:
@staticmethod
@ServiceProvider.inject
def test(services: ServiceProviderABC, t1: TestService):
@inject
def test(services: ServiceProvider, t1: TestService):
t1.run()

View File

@@ -6,7 +6,7 @@ from di.test_abc import TestABC
class Test1Service(TestABC):
def __init__(self):
TestABC.__init__(self, String.random_string(string.ascii_lowercase, 8))
TestABC.__init__(self, String.random(8))
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -6,7 +6,7 @@ from di.test_abc import TestABC
class Test2Service(TestABC):
def __init__(self):
TestABC.__init__(self, String.random_string(string.ascii_lowercase, 8))
TestABC.__init__(self, String.random(8))
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -1,5 +1,3 @@
import string
from cpl.core.console.console import Console
from cpl.core.utils.string import String
@@ -8,5 +6,9 @@ class TestService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -4,19 +4,20 @@ from typing import Optional
from cpl.application.abc import ApplicationABC
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
from cpl.core.environment import Environment
from cpl.core.log import LoggerABC
from cpl.core.pipes import IPAddressPipe
from cpl.mail import EMail, EMailClientABC
from cpl.query.extension.list import List
from cpl.query import List
from general.scoped_service import ScopedService
from test_service import TestService
from test_settings import TestSettings
class Application(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services)
self._logger = self._services.get_service(LoggerABC)
self._mailer = self._services.get_service(EMailClientABC)
@@ -38,7 +39,7 @@ class Application(ApplicationABC):
def main(self):
self._logger.debug(f"Host: {Environment.get_host_name()}")
self._logger.debug(f"Environment: {Environment.get_environment()}")
Console.write_line(List(int, range(0, 10)).select(lambda x: f"x={x}").to_list())
Console.write_line(List(range(0, 10)).select(lambda x: f"x={x}").to_list())
Console.spinner("Test", self._wait, 2, spinner_foreground_color="red")
test: TestService = self._services.get_service(TestService)
ip_pipe: IPAddressPipe = self._services.get_service(IPAddressPipe)
@@ -48,10 +49,21 @@ class Application(ApplicationABC):
Console.write_line(f"DI working: {test == test2 and ip_pipe != ip_pipe2}")
Console.write_line(self._services.get_service(LoggerABC))
scope = self._services.create_scope()
Console.write_line("scope", scope)
with self._services.create_scope() as s:
Console.write_line("with scope", s)
root_scoped_service = self._services.get_service(ScopedService)
with self._services.create_scope() as scope:
s_srvc1 = scope.get_service(ScopedService)
s_srvc2 = scope.get_service(ScopedService)
Console.write_line(root_scoped_service)
Console.write_line(s_srvc1)
Console.write_line(s_srvc2)
if root_scoped_service == s_srvc1 or s_srvc1 != s_srvc2:
raise Exception("Root scoped service should not be equal to scoped service")
root_scoped_service2 = self._services.get_service(ScopedService)
Console.write_line(root_scoped_service2)
if root_scoped_service == root_scoped_service2:
raise Exception("Root scoped service should be equal to root scoped service 2")
test_settings = Configuration.get(TestSettings)
Console.write_line(test_settings.value)

View File

@@ -0,0 +1,10 @@
from cpl.core.console import Console
class ScopedService:
def __init__(self):
self.value = "I am a scoped service"
Console.write_line(self.value, self)
def get_value(self):
return self.value

View File

@@ -4,6 +4,7 @@ from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.core.pipes import IPAddressPipe
from cpl.dependency import ServiceCollection
from general.scoped_service import ScopedService
from test_service import TestService
@@ -21,3 +22,4 @@ class Startup(StartupABC):
services.add_module(mail)
services.add_transient(IPAddressPipe)
services.add_singleton(TestService)
services.add_scoped(ScopedService)

View File

@@ -1,10 +1,10 @@
from cpl.application.abc import ApplicationExtensionABC
from cpl.core.console import Console
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class TestExtension(ApplicationExtensionABC):
@staticmethod
def run(services: ServiceProviderABC):
def run(services: ServiceProvider):
Console.write_line("Hello World from App Extension")

View File

@@ -1,10 +1,10 @@
from cpl.core.console.console import Console
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
from cpl.core.pipes.ip_address_pipe import IPAddressPipe
class TestService:
def __init__(self, provider: ServiceProviderABC):
def __init__(self, provider: ServiceProvider):
self._provider = provider
def run(self):

View File

@@ -1,14 +1,14 @@
from cpl.application import ApplicationABC
from cpl.core.configuration import ConfigurationABC
from cpl.core.console import Console
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
from cpl.translation.translate_pipe import TranslatePipe
from cpl.translation.translation_service_abc import TranslationServiceABC
from cpl.translation.translation_settings import TranslationSettings
class Application(ApplicationABC):
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC):
def __init__(self, config: ConfigurationABC, services: ServiceProvider):
ApplicationABC.__init__(self, config, services)
self._translate: TranslatePipe = services.get_service(TranslatePipe)

View File

@@ -1,6 +1,6 @@
from cpl.application import StartupABC
from cpl.core.configuration import ConfigurationABC
from cpl.dependency import ServiceProviderABC, ServiceCollection
from cpl.dependency import ServiceProvider, ServiceCollection
from cpl.core.environment import Environment
@@ -12,6 +12,6 @@ class Startup(StartupABC):
configuration.add_json_file("appsettings.json")
return configuration
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProviderABC:
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProvider:
services.add_translation()
return services.build()

View File

@@ -27,14 +27,14 @@ from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
def __init__(self, services: ServiceProvider):
super().__init__(services, [auth, api])
self._app: Starlette | None = None
@@ -44,15 +44,15 @@ class WebApp(ApplicationABC):
self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._middleware: list[Middleware] = []
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception,
APIError: self._handle_exception,
}
self.with_middleware(RequestMiddleware)
self.with_middleware(LoggingMiddleware)
async def _handle_exception(self, request: Request, exc: Exception):
if isinstance(exc, APIError):
self._logger.error(exc)
@@ -168,9 +168,9 @@ class WebApp(ApplicationABC):
self._check_for_app()
if isinstance(middleware, Middleware):
self._middleware.append(middleware)
self._middleware.append(inject(middleware))
elif callable(middleware):
self._middleware.append(Middleware(middleware))
self._middleware.append(Middleware(inject(middleware)))
else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
@@ -220,7 +220,7 @@ class WebApp(ApplicationABC):
self._validate_policies()
if self._app is None:
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
routes = [route.to_starlette(inject) for route in self._routes.all()]
app = Starlette(
routes=routes,

View File

@@ -9,12 +9,10 @@ from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
class AuthenticationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)

View File

@@ -9,12 +9,10 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
class AuthorizationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)

View File

@@ -6,12 +6,10 @@ from starlette.types import Receive, Scope, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
from cpl.dependency import ServiceProviderABC
class LoggingMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger):
ASGIMiddleware.__init__(self, app)

View File

@@ -9,17 +9,18 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest
from cpl.dependency import ServiceProviderABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
class RequestMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger):
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._ctx_token = None
@@ -29,7 +30,8 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request)
try:
await self._app(scope, receive, send)
with self._provider.create_scope():
inject(await self._app(scope, receive, send))
finally:
await self.clean_request_data()

View File

@@ -3,6 +3,7 @@ from enum import Enum
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods
from cpl.dependency import get_provider
class Router:
@@ -95,9 +96,7 @@ class Router:
from cpl.api.model.api_route import ApiRoute
if not registry:
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
routes = get_provider().get_service(RouteRegistry)
else:
routes = registry
@@ -144,9 +143,8 @@ class Router:
"""
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
routes = get_provider().get_service(RouteRegistry)
def inner(fn):
path = getattr(fn, "_route_path", None)

View File

@@ -5,7 +5,7 @@ from cpl.application.host import Host
from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.service_provider import ServiceProvider
def __not_implemented__(package: str, func: Callable):
@@ -16,12 +16,12 @@ class ApplicationABC(ABC):
r"""ABC for the Application class
Parameters:
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
services: :class:`cpl.dependency.service_provider.ServiceProvider`
Contains instances of prepared objects
"""
@abstractmethod
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
self._services = services
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []

View File

@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.service_provider import ServiceProvider
class ApplicationExtensionABC(ABC):
@staticmethod
@abstractmethod
def run(services: ServiceProviderABC): ...
def run(services: ServiceProvider): ...

View File

@@ -7,6 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host
from cpl.core.errors import dependency_error
from cpl.dependency.context import get_provider, use_root_provider
from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -21,6 +22,7 @@ class ApplicationBuilder(Generic[TApp]):
self._app = app if app is not None else ApplicationABC
self._services = ServiceCollection()
use_root_provider(self._services.build())
self._startup: Optional[StartupABC] = None
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
@@ -34,7 +36,12 @@ class ApplicationBuilder(Generic[TApp]):
@property
def service_provider(self):
return self._services.build()
provider = get_provider()
if provider is None:
provider = self._services.build()
use_root_provider(provider)
return provider
def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules:

View File

@@ -1,5 +1,5 @@
from cpl.core.utils.get_value import get_value
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class KeycloakUser:
@@ -32,5 +32,5 @@ class KeycloakUser:
def id(self) -> str:
from cpl.auth import KeycloakAdmin
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
return keycloak_admin.get_user_id(self._username)

View File

@@ -10,7 +10,8 @@ from cpl.core.log.logger import Logger
from cpl.core.typing import Id, SerialId
from cpl.core.utils.credential_manager import CredentialManager
from cpl.database.abc.db_model_abc import DbModelABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency import get_provider
from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__)
@@ -47,7 +48,7 @@ class ApiKey(DbModelABC):
async def permissions(self):
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]

View File

@@ -10,7 +10,7 @@ from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class AuthUser(DbModelABC):
@@ -36,12 +36,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS"
try:
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e:
return "UNKNOWN"
except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger)
logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN"
@@ -51,12 +51,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS"
try:
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e:
return "UNKNOWN"
except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger)
logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN"
@@ -64,26 +64,26 @@ class AuthUser(DbModelABC):
async def roles(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
@async_property
async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.has_permission(self.id, permission)
async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self)

View File

@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class AuthUserDao(DbModelDaoABC[AuthUser]):
@@ -36,7 +36,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map(
f"""

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class ApiKeyPermission(DbJoinModelABC):
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
async def api_key(self):
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
return await api_key_dao.get_by_id(self._api_key_id)
@property
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id)

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class Role(DbModelABC):
@@ -44,22 +44,22 @@ class Role(DbModelABC):
async def permissions(self):
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
@async_property
async def users(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
p = await permission_dao.get_by_name(permission.value)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class RolePermission(DbModelABC):
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
role_dao: RoleDao = get_provider().get_service(RoleDao)
return await role_dao.get_by_id(self._role_id)
@property
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class RoleUser(DbJoinModelABC):
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._user_id)
@property
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
role_dao: RoleDao = get_provider().get_service(RoleDao)
return await role_dao.get_by_id(self._role_id)

View File

@@ -2,15 +2,15 @@ from contextvars import ContextVar
from typing import Optional
from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
def set_user(user: Optional[AuthUser]):
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.core.log.logger_abc import LoggerABC
logger = ServiceProviderABC.get_global_service(LoggerABC)
logger = get_provider().get_service(LoggerABC)
logger.trace("Setting user context", user.id)
_user_context.set(user)

View File

@@ -9,6 +9,7 @@ from starlette.requests import Request
from cpl.core.log.log_level import LogLevel
from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages
from cpl.dependency import get_provider
class StructuredLogger(Logger):
@@ -99,10 +100,9 @@ class StructuredLogger(Logger):
if user is None:
return
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
keycloak = get_provider().get_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = {
"id": str(user.id),

View File

@@ -3,7 +3,8 @@ from typing import Type
from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
class WrappedLogger(LoggerABC):
@@ -17,13 +18,13 @@ class WrappedLogger(LoggerABC):
self._set_logger()
@ServiceProviderABC.inject
def _set_logger(self, services: ServiceProviderABC):
@inject
def _set_logger(self, services: ServiceProvider):
from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProviderABC")
raise Exception("No LoggerABC service registered in ServiceProvider")
self._logger = t_logger(self._source, self._file_prefix)
@@ -42,8 +43,8 @@ class WrappedLogger(LoggerABC):
from cpl.dependency import ServiceCollection
ignore_classes = [
ServiceProviderABC,
ServiceProviderABC.__subclasses__(),
ServiceProvider,
ServiceProvider.__subclasses__(),
ServiceCollection,
WrappedLogger,
WrappedLogger.__subclasses__(),

View File

@@ -114,12 +114,15 @@ class String:
characters = []
if letters:
characters.append(string.ascii_letters)
characters.extend(string.ascii_letters)
if digits:
characters.append(string.digits)
characters.extend(string.digits)
if special_characters:
characters.append(string.punctuation)
characters.extend(string.punctuation)
return "".join(random.choice(characters) for _ in range(length)) if characters else ""
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else ""
if len(x) != length:
raise Exception("No characters selected to generate random string")
return x

View File

@@ -9,21 +9,19 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT
from cpl.database.logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency import get_provider
class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str):
from cpl.dependency.service_provider_abc import ServiceProviderABC
self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = ServiceProviderABC.get_global_service(DBLogger)
self._db = get_provider().get_service(DBContextABC)
self._logger = get_provider().get_service(DBLogger)
self._model_type = model_type
self._table_name = table_name

View File

@@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment
from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class MySQLPool:
@@ -35,7 +35,7 @@ class MySQLPool:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger)
logger = get_provider().get_service(DBLogger)
logger.fatal(f"Error connecting to the database: {e}")
finally:
await con.close()

View File

@@ -7,7 +7,7 @@ from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment
from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class PostgresPool:
@@ -37,7 +37,7 @@ class PostgresPool:
await pool.check_connection(con)
except PoolTimeout as e:
await pool.close()
logger = ServiceProviderABC.get_global_service(DBLogger)
logger = get_provider().get_service(DBLogger)
logger.fatal(f"Failed to connect to the database", e)
self._pool = pool

View File

@@ -1,11 +1,11 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
class SeederService:
def __init__(self, provider: ServiceProviderABC):
def __init__(self, provider: ServiceProvider):
self._provider = provider
self._logger = provider.get_service(DBLogger)

View File

@@ -1,7 +1,7 @@
from .scope import Scope
from .scope_abc import ScopeABC
from .context import get_provider, use_provider
from .inject import inject
from .service_collection import ServiceCollection
from .service_descriptor import ServiceDescriptor
from .service_lifetime_enum import ServiceLifetimeEnum
from .service_lifetime import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_provider import ServiceProvider
from .service_provider_abc import ServiceProviderABC

View File

@@ -0,0 +1,21 @@
import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider):
_current_provider.set(provider)
@contextmanager
def use_provider(provider):
token = _current_provider.set(provider)
try:
yield
finally:
_current_provider.reset(token)
def get_provider():
return _current_provider.get()

View File

@@ -0,0 +1,42 @@
import functools
from asyncio import iscoroutinefunction
from inspect import signature
from cpl.dependency.context import get_provider
def inject(f=None):
if f is None:
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -1,22 +0,0 @@
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class Scope(ScopeABC):
def __init__(self, service_provider: ServiceProviderABC):
self._service_provider = service_provider
self._service_provider.set_scope(self)
ScopeABC.__init__(self)
def __enter__(self):
return self
def __exit__(self, *args):
self.dispose()
@property
def service_provider(self) -> ServiceProviderABC:
return self._service_provider
def dispose(self):
self._service_provider = None

View File

@@ -1,20 +0,0 @@
from abc import ABC, abstractmethod
class ScopeABC(ABC):
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
def __init__(self): ...
@property
@abstractmethod
def service_provider(self):
r"""Returns to service provider of scope
Returns:
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
"""
@abstractmethod
def dispose(self):
r"""Sets service_provider to None"""

View File

@@ -1,18 +0,0 @@
from cpl.dependency.scope import Scope
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ScopeBuilder:
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
def __init__(self, service_provider: ServiceProviderABC) -> None:
self._service_provider = service_provider
def build(self) -> ScopeABC:
r"""Returns scope
Returns:
Object of type :class:`cpl.dependency.scope.Scope`
"""
return Scope(self._service_provider)

View File

@@ -4,9 +4,8 @@ from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache
from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceCollection:
@@ -62,9 +61,8 @@ class ServiceCollection:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self
def build(self) -> ServiceProviderABC:
def build(self) -> ServiceProvider:
sp = ServiceProvider(self._service_descriptors)
ServiceProviderABC.set_global_provider(sp)
return sp
def add_module(self, module: str | object) -> Self:

View File

@@ -1,6 +1,6 @@
from typing import Union, Optional
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
class ServiceDescriptor:

View File

@@ -0,0 +1,7 @@
from enum import Enum, auto
class ServiceLifetimeEnum(Enum):
singleton = auto()
scoped = auto()
transient = auto()

View File

@@ -1,7 +0,0 @@
from enum import Enum
class ServiceLifetimeEnum(Enum):
singleton = 0
scoped = 1
transient = 2

View File

@@ -1,5 +1,6 @@
import copy
import typing
from contextlib import contextmanager
from inspect import signature, Parameter, Signature
from typing import Optional, Type
@@ -7,34 +8,15 @@ from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
from cpl.core.environment import Environment
from cpl.core.typing import T, R, Source
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.scope_builder import ScopeBuilder
from cpl.dependency import use_provider
from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
class ServiceProvider(ServiceProviderABC):
r"""Provider for the services
Parameter
---------
service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`]
Descriptor of the service
config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC`
CPL Configuration
db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`]
Database representation
"""
def __init__(
self,
service_descriptors: list[ServiceDescriptor],
):
ServiceProviderABC.__init__(self)
class ServiceProvider:
def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False):
self._service_descriptors: list[ServiceDescriptor] = service_descriptors
self._scope: Optional[ScopeABC] = None
self._is_scope = is_scope
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type
@@ -67,7 +49,7 @@ class ServiceProvider(ServiceProviderABC):
return descriptor.implementation
implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
descriptor.implementation = implementation
return implementation
@@ -85,7 +67,7 @@ class ServiceProvider(ServiceProviderABC):
implementation = self._build_service(
descriptor.service_type, origin_service_type=service_type, **kwargs
)
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
descriptor.implementation = implementation
implementations.append(implementation)
@@ -103,7 +85,7 @@ class ServiceProvider(ServiceProviderABC):
elif parameter.annotation == Source:
params.append(origin_service_type.__name__)
elif issubclass(parameter.annotation, ServiceProviderABC):
elif issubclass(parameter.annotation, ServiceProvider):
params.append(self)
elif issubclass(parameter.annotation, Environment):
@@ -131,32 +113,27 @@ class ServiceProvider(ServiceProviderABC):
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
break
sig = signature(service_type.__init__)
params = self._build_by_signature(sig, origin_service_type)
return service_type(*params, *args, **kwargs)
def set_scope(self, scope: ScopeABC):
self._scope = scope
def create_scope(self) -> ScopeABC:
descriptors = []
for descriptor in self._service_descriptors:
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptors.append(descriptor)
@contextmanager
def create_scope(self):
scoped_descriptors = []
for d in self._service_descriptors:
if d.lifetime == ServiceLifetimeEnum.singleton:
scoped_descriptors.append(d)
else:
descriptors.append(copy.deepcopy(descriptor))
scoped_descriptors.append(copy.deepcopy(d))
sb = ScopeBuilder(ServiceProvider(descriptors))
return sb.build()
scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True)
with use_provider(scoped_provider):
yield scoped_provider
def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
result = self._find_service(service_type)
if result is None:
return None
@@ -164,11 +141,10 @@ class ServiceProvider(ServiceProviderABC):
return result.implementation
implementation = self._build_service(service_type, *args, **kwargs)
if (
result.lifetime == ServiceLifetimeEnum.singleton
or result.lifetime == ServiceLifetimeEnum.scoped
and self._scope is not None
):
if result.lifetime == ServiceLifetimeEnum.singleton:
result.implementation = implementation
elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope:
result.implementation = implementation
return implementation
@@ -181,12 +157,9 @@ class ServiceProvider(ServiceProviderABC):
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
implementations = []
if typing.get_origin(service_type) == list:
raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
implementations.extend(self._get_services(service_type))
return implementations
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:

View File

@@ -1,165 +0,0 @@
import functools
from abc import abstractmethod, ABC
from inspect import Signature, signature, iscoroutinefunction
from typing import Optional, Type
from cpl.core.typing import T, R
from cpl.dependency.scope_abc import ScopeABC
class ServiceProviderABC(ABC):
r"""ABC for the class :class:`cpl.dependency.service_provider.ServiceProvider`"""
_provider: Optional["ServiceProviderABC"] = None
@abstractmethod
def __init__(self): ...
@classmethod
def set_global_provider(cls, provider: "ServiceProviderABC"):
cls._provider = provider
@classmethod
def get_global_provider(cls) -> Optional["ServiceProviderABC"]:
return cls._provider
@classmethod
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
if cls._provider is None:
return None
return cls._provider.get_service(instance_type, *args, **kwargs)
@classmethod
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
if cls._provider is None:
return []
return cls._provider.get_services(instance_type, *args, **kwargs)
@abstractmethod
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
@abstractmethod
def _build_service(self, service_type: type, *args, **kwargs) -> object:
r"""Creates instance of given type
Parameter
---------
instance_type: :class:`type`
The type of the searched instance
Returns
-------
Object of the given type
"""
@abstractmethod
def set_scope(self, scope: ScopeABC):
r"""Sets the scope of service provider
Parameter
---------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
Service scope
"""
@abstractmethod
def create_scope(self) -> ScopeABC:
r"""Creates a service scope
Returns
-------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
"""
@abstractmethod
def get_service(self, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
r"""Returns instance of given type
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`cpl.core.type.T`]
"""
@abstractmethod
def get_service_type(self, instance_type: Type[T]) -> Optional[Type[T]]:
r"""Returns the registered service type for loggers
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`type`]
"""
@abstractmethod
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[Optional[:class:`cpl.core.type.T`]
"""
@abstractmethod
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
r"""Returns all registered service types
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[:class:`type`]
"""
@classmethod
def inject(cls, f=None):
r"""Decorator to allow injection into static and class methods
Parameter
---------
f: Callable
Returns
-------
function
"""
if f is None:
return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -2,7 +2,7 @@ import unittest
from cpl.application import ApplicationABC
from cpl.core.configuration import ConfigurationABC
from cpl.dependency import ServiceProviderABC
from cpl.dependency import ServiceProvider
from unittests_cli.cli_test_suite import CLITestSuite
from unittests_core.core_test_suite import CoreTestSuite
from unittests_query.query_test_suite import QueryTestSuite
@@ -10,7 +10,7 @@ from unittests_translation.translation_test_suite import TranslationTestSuite
class Application(ApplicationABC):
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC):
def __init__(self, config: ConfigurationABC, services: ServiceProvider):
ApplicationABC.__init__(self, config, services)
def configure(self): ...

View File

@@ -2,7 +2,7 @@ import unittest
from unittest.mock import Mock
from cpl.core.configuration import Configuration
from cpl.dependency import ServiceCollection, ServiceLifetimeEnum, ServiceProviderABC
from cpl.dependency import ServiceCollection, ServiceLifetimeEnum, ServiceProvider
class ServiceCollectionTestCase(unittest.TestCase):
@@ -51,6 +51,6 @@ class ServiceCollectionTestCase(unittest.TestCase):
service = self._sc._service_descriptors[0]
self.assertIsNone(service.implementation)
sp = self._sc.build()
self.assertTrue(isinstance(sp, ServiceProviderABC))
self.assertTrue(isinstance(sp, ServiceProvider))
self.assertTrue(isinstance(sp.get_service(Mock), Mock))
self.assertIsNotNone(service.implementation)

View File

@@ -1,7 +1,7 @@
import unittest
from cpl.core.configuration import Configuration
from cpl.dependency import ServiceCollection, ServiceProviderABC
from cpl.dependency import ServiceCollection, ServiceProvider
class ServiceCount:
@@ -10,21 +10,21 @@ class ServiceCount:
class TestService:
def __init__(self, sp: ServiceProviderABC, count: ServiceCount):
def __init__(self, sp: ServiceProvider, count: ServiceCount):
count.count += 1
self.sp = sp
self.id = count.count
class DifferentService:
def __init__(self, sp: ServiceProviderABC, count: ServiceCount):
def __init__(self, sp: ServiceProvider, count: ServiceCount):
count.count += 1
self.sp = sp
self.id = count.count
class MoreDifferentService:
def __init__(self, sp: ServiceProviderABC, count: ServiceCount):
def __init__(self, sp: ServiceProvider, count: ServiceCount):
count.count += 1
self.sp = sp
self.id = count.count
@@ -72,7 +72,7 @@ class ServiceProviderTestCase(unittest.TestCase):
singleton = self._services.get_service(TestService)
transient = self._services.get_service(DifferentService)
with self._services.create_scope() as scope:
sp: ServiceProviderABC = scope.service_provider
sp: ServiceProvider = scope.service_provider
self.assertNotEqual(sp, self._services)
y = sp.get_service(DifferentService)
self.assertIsNotNone(y)