WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
44 changed files with 108 additions and 326 deletions
Showing only changes of commit 4c8cd988cc - Show all commits

View File

@@ -3,7 +3,7 @@ from cpl.auth.keycloak import KeycloakAdmin
from cpl.core.console import Console from cpl.core.console import Console
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.log import LoggerABC from cpl.core.log import LoggerABC
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
from model.city import City from model.city import City
from model.city_dao import CityDao from model.city_dao import CityDao
from model.user import User from model.user import User
@@ -11,7 +11,7 @@ from model.user_dao import UserDao
class Application(ApplicationABC): class Application(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services) ApplicationABC.__init__(self, services)
self._logger = services.get_service(LoggerABC) self._logger = services.get_service(LoggerABC)

View File

@@ -1,6 +1,6 @@
from cpl.application.abc import ApplicationABC from cpl.application.abc import ApplicationABC
from cpl.core.console.console import Console from cpl.core.console.console import Console
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
from cpl.dependency.scope import Scope from cpl.dependency.scope import Scope
from di.static_test import StaticTest from di.static_test import StaticTest
from di.test_abc import TestABC from di.test_abc import TestABC
@@ -10,7 +10,7 @@ from di.tester import Tester
class Application(ApplicationABC): class Application(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services) ApplicationABC.__init__(self, services)
def _part_of_scoped(self): def _part_of_scoped(self):

View File

@@ -1,5 +1,5 @@
from cpl.application.abc import StartupABC from cpl.application.abc import StartupABC
from cpl.dependency import ServiceProviderABC, ServiceCollection from cpl.dependency import ServiceProvider, ServiceCollection
from di.di_tester_service import DITesterService from di.di_tester_service import DITesterService
from di.test1_service import Test1Service from di.test1_service import Test1Service
from di.test2_service import Test2Service from di.test2_service import Test2Service
@@ -14,7 +14,7 @@ class Startup(StartupABC):
def configure_configuration(self): ... def configure_configuration(self): ...
def configure_services(self, services: ServiceCollection) -> ServiceProviderABC: def configure_services(self, services: ServiceCollection) -> ServiceProvider:
services.add_scoped(TestService) services.add_scoped(TestService)
services.add_scoped(DITesterService) services.add_scoped(DITesterService)

View File

@@ -1,4 +1,4 @@
from cpl.dependency import ServiceProvider, ServiceProviderABC from cpl.dependency import ServiceProvider, ServiceProvider
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from di.test_service import TestService from di.test_service import TestService
@@ -6,5 +6,5 @@ from di.test_service import TestService
class StaticTest: class StaticTest:
@staticmethod @staticmethod
@inject @inject
def test(services: ServiceProviderABC, t1: TestService): def test(services: ServiceProvider, t1: TestService):
t1.run() t1.run()

View File

@@ -4,7 +4,7 @@ from typing import Optional
from cpl.application.abc import ApplicationABC from cpl.application.abc import ApplicationABC
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.console import Console from cpl.core.console import Console
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.log import LoggerABC from cpl.core.log import LoggerABC
from cpl.core.pipes import IPAddressPipe from cpl.core.pipes import IPAddressPipe
@@ -16,7 +16,7 @@ from test_settings import TestSettings
class Application(ApplicationABC): class Application(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProvider):
ApplicationABC.__init__(self, services) ApplicationABC.__init__(self, services)
self._logger = self._services.get_service(LoggerABC) self._logger = self._services.get_service(LoggerABC)
self._mailer = self._services.get_service(EMailClientABC) self._mailer = self._services.get_service(EMailClientABC)

View File

@@ -1,10 +1,10 @@
from cpl.application.abc import ApplicationExtensionABC from cpl.application.abc import ApplicationExtensionABC
from cpl.core.console import Console from cpl.core.console import Console
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class TestExtension(ApplicationExtensionABC): class TestExtension(ApplicationExtensionABC):
@staticmethod @staticmethod
def run(services: ServiceProviderABC): def run(services: ServiceProvider):
Console.write_line("Hello World from App Extension") Console.write_line("Hello World from App Extension")

View File

@@ -1,10 +1,10 @@
from cpl.core.console.console import Console from cpl.core.console.console import Console
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
from cpl.core.pipes.ip_address_pipe import IPAddressPipe from cpl.core.pipes.ip_address_pipe import IPAddressPipe
class TestService: class TestService:
def __init__(self, provider: ServiceProviderABC): def __init__(self, provider: ServiceProvider):
self._provider = provider self._provider = provider
def run(self): def run(self):

View File

@@ -1,14 +1,14 @@
from cpl.application import ApplicationABC from cpl.application import ApplicationABC
from cpl.core.configuration import ConfigurationABC from cpl.core.configuration import ConfigurationABC
from cpl.core.console import Console from cpl.core.console import Console
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
from cpl.translation.translate_pipe import TranslatePipe from cpl.translation.translate_pipe import TranslatePipe
from cpl.translation.translation_service_abc import TranslationServiceABC from cpl.translation.translation_service_abc import TranslationServiceABC
from cpl.translation.translation_settings import TranslationSettings from cpl.translation.translation_settings import TranslationSettings
class Application(ApplicationABC): class Application(ApplicationABC):
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC): def __init__(self, config: ConfigurationABC, services: ServiceProvider):
ApplicationABC.__init__(self, config, services) ApplicationABC.__init__(self, config, services)
self._translate: TranslatePipe = services.get_service(TranslatePipe) self._translate: TranslatePipe = services.get_service(TranslatePipe)

View File

@@ -1,6 +1,6 @@
from cpl.application import StartupABC from cpl.application import StartupABC
from cpl.core.configuration import ConfigurationABC from cpl.core.configuration import ConfigurationABC
from cpl.dependency import ServiceProviderABC, ServiceCollection from cpl.dependency import ServiceProvider, ServiceCollection
from cpl.core.environment import Environment from cpl.core.environment import Environment
@@ -12,6 +12,6 @@ class Startup(StartupABC):
configuration.add_json_file("appsettings.json") configuration.add_json_file("appsettings.json")
return configuration return configuration
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProviderABC: def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProvider:
services.add_translation() services.add_translation()
return services.build() return services.build()

View File

@@ -28,14 +28,14 @@ from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider import ServiceProvider
PolicyInput = Union[dict[str, PolicyResolver], Policy] PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC): class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProvider):
super().__init__(services, [auth, api]) super().__init__(services, [auth, api])
self._app: Starlette | None = None self._app: Starlette | None = None

View File

@@ -3,6 +3,7 @@ from enum import Enum
from cpl.api.model.validation_match import ValidationMatch from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods from cpl.api.typing import HTTPMethods
from cpl.dependency import get_provider
class Router: class Router:
@@ -95,9 +96,7 @@ class Router:
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
if not registry: if not registry:
from cpl.dependency.service_provider_abc import ServiceProviderABC routes = get_provider().get_service(RouteRegistry)
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else: else:
routes = registry routes = registry
@@ -144,9 +143,8 @@ class Router:
""" """
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry) routes = get_provider().get_service(RouteRegistry)
def inner(fn): def inner(fn):
path = getattr(fn, "_route_path", None) path = getattr(fn, "_route_path", None)

View File

@@ -5,7 +5,7 @@ from cpl.application.host import Host
from cpl.core.log.log_level import LogLevel from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider import ServiceProvider
def __not_implemented__(package: str, func: Callable): def __not_implemented__(package: str, func: Callable):
@@ -16,12 +16,12 @@ class ApplicationABC(ABC):
r"""ABC for the Application class r"""ABC for the Application class
Parameters: Parameters:
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC` services: :class:`cpl.dependency.service_provider.ServiceProvider`
Contains instances of prepared objects Contains instances of prepared objects
""" """
@abstractmethod @abstractmethod
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None): def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
self._services = services self._services = services
self._required_modules = ( self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else [] [x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []

View File

@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider import ServiceProvider
class ApplicationExtensionABC(ABC): class ApplicationExtensionABC(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def run(services: ServiceProviderABC): ... def run(services: ServiceProvider): ...

View File

@@ -7,7 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.errors import dependency_error from cpl.core.errors import dependency_error
from cpl.dependency.context import get_current_provider, use_root_provider from cpl.dependency.context import get_provider, use_root_provider
from cpl.dependency.service_collection import ServiceCollection from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC) TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -36,7 +36,7 @@ class ApplicationBuilder(Generic[TApp]):
@property @property
def service_provider(self): def service_provider(self):
provider = get_current_provider() provider = get_provider()
if provider is None: if provider is None:
provider = self._services.build() provider = self._services.build()
use_root_provider(provider) use_root_provider(provider)

View File

@@ -1,5 +1,5 @@
from cpl.core.utils.get_value import get_value from cpl.core.utils.get_value import get_value
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class KeycloakUser: class KeycloakUser:
@@ -32,5 +32,5 @@ class KeycloakUser:
def id(self) -> str: def id(self) -> str:
from cpl.auth import KeycloakAdmin from cpl.auth import KeycloakAdmin
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
return keycloak_admin.get_user_id(self._username) return keycloak_admin.get_user_id(self._username)

View File

@@ -10,7 +10,8 @@ from cpl.core.log.logger import Logger
from cpl.core.typing import Id, SerialId from cpl.core.typing import Id, SerialId
from cpl.core.utils.credential_manager import CredentialManager from cpl.core.utils.credential_manager import CredentialManager
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency import get_provider
from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__) _logger = Logger(__name__)
@@ -47,7 +48,7 @@ class ApiKey(DbModelABC):
async def permissions(self): async def permissions(self):
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao) apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)] return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]

View File

@@ -10,7 +10,7 @@ from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class AuthUser(DbModelABC): class AuthUser(DbModelABC):
@@ -36,12 +36,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin) keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username") return keycloak.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger) logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e) logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@@ -51,12 +51,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin) keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email") return keycloak.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger) logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e) logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@@ -64,26 +64,26 @@ class AuthUser(DbModelABC):
async def roles(self): async def roles(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao) role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)] return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
@async_property @async_property
async def permissions(self): async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao) auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.get_permissions(self.id) return await auth_user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao) auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.has_permission(self.id, permission) return await auth_user_dao.has_permission(self.id, permission)
async def anonymize(self): async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao) auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
self._keycloak_id = str(uuid.UUID(int=0)) self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self) await auth_user_dao.update(self)

View File

@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class AuthUserDao(DbModelDaoABC[AuthUser]): class AuthUserDao(DbModelDaoABC[AuthUser]):
@@ -36,7 +36,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool: async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao) permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value) p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class ApiKeyPermission(DbJoinModelABC): class ApiKeyPermission(DbJoinModelABC):
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
async def api_key(self): async def api_key(self):
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao) api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
return await api_key_dao.get_by_id(self._api_key_id) return await api_key_dao.get_by_id(self._api_key_id)
@property @property
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao) permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._permission_id)

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class Role(DbModelABC): class Role(DbModelABC):
@@ -44,22 +44,22 @@ class Role(DbModelABC):
async def permissions(self): async def permissions(self):
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao) role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)] return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
@async_property @async_property
async def users(self): async def users(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao) role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)] return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao) permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao) role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
p = await permission_dao.get_by_name(permission.value) p = await permission_dao.get_by_name(permission.value)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class RolePermission(DbModelABC): class RolePermission(DbModelABC):
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao) role_dao: RoleDao = get_provider().get_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._role_id)
@property @property
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao) permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._permission_id)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class RoleUser(DbJoinModelABC): class RoleUser(DbJoinModelABC):
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
async def user(self): async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao) auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._user_id) return await auth_user_dao.get_by_id(self._user_id)
@property @property
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao) role_dao: RoleDao = get_provider().get_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._role_id)

View File

@@ -2,15 +2,15 @@ from contextvars import ContextVar
from typing import Optional from typing import Optional
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) _user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
def set_user(user: Optional[AuthUser]): def set_user(user: Optional[AuthUser]):
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
logger = ServiceProviderABC.get_global_service(LoggerABC) logger = get_provider().get_service(LoggerABC)
logger.trace("Setting user context", user.id) logger.trace("Setting user context", user.id)
_user_context.set(user) _user_context.set(user)

View File

@@ -9,6 +9,7 @@ from starlette.requests import Request
from cpl.core.log.log_level import LogLevel from cpl.core.log.log_level import LogLevel
from cpl.core.log.logger import Logger from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages from cpl.core.typing import Source, Messages
from cpl.dependency import get_provider
class StructuredLogger(Logger): class StructuredLogger(Logger):
@@ -99,10 +100,9 @@ class StructuredLogger(Logger):
if user is None: if user is None:
return return
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin) keycloak = get_provider().get_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id) kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = { message["user"] = {
"id": str(user.id), "id": str(user.id),

View File

@@ -4,7 +4,7 @@ from typing import Type
from cpl.core.log import LoggerABC, LogLevel from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages from cpl.core.typing import Messages
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider import ServiceProvider
class WrappedLogger(LoggerABC): class WrappedLogger(LoggerABC):
@@ -19,12 +19,12 @@ class WrappedLogger(LoggerABC):
self._set_logger() self._set_logger()
@inject @inject
def _set_logger(self, services: ServiceProviderABC): def _set_logger(self, services: ServiceProvider):
from cpl.core.log import Logger from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC) t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None: if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProviderABC") raise Exception("No LoggerABC service registered in ServiceProvider")
self._logger = t_logger(self._source, self._file_prefix) self._logger = t_logger(self._source, self._file_prefix)
@@ -43,8 +43,8 @@ class WrappedLogger(LoggerABC):
from cpl.dependency import ServiceCollection from cpl.dependency import ServiceCollection
ignore_classes = [ ignore_classes = [
ServiceProviderABC, ServiceProvider,
ServiceProviderABC.__subclasses__(), ServiceProvider.__subclasses__(),
ServiceCollection, ServiceCollection,
WrappedLogger, WrappedLogger,
WrappedLogger.__subclasses__(), WrappedLogger.__subclasses__(),

View File

@@ -9,21 +9,19 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT from cpl.database.const import DATETIME_FORMAT
from cpl.database.logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency import get_provider
class DataAccessObjectABC(ABC, Generic[T_DBM]): class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, model_type: Type[T_DBM], table_name: str):
from cpl.dependency.service_provider_abc import ServiceProviderABC self._db = get_provider().get_service(DBContextABC)
self._logger = get_provider().get_service(DBLogger)
self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = ServiceProviderABC.get_global_service(DBLogger)
self._model_type = model_type self._model_type = model_type
self._table_name = table_name self._table_name = table_name

View File

@@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class MySQLPool: class MySQLPool:
@@ -35,7 +35,7 @@ class MySQLPool:
await cursor.execute("SELECT 1") await cursor.execute("SELECT 1")
await cursor.fetchall() await cursor.fetchall()
except Exception as e: except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger) logger = get_provider().get_service(DBLogger)
logger.fatal(f"Error connecting to the database: {e}") logger.fatal(f"Error connecting to the database: {e}")
finally: finally:
await con.close() await con.close()

View File

@@ -7,7 +7,7 @@ from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class PostgresPool: class PostgresPool:
@@ -37,7 +37,7 @@ class PostgresPool:
await pool.check_connection(con) await pool.check_connection(con)
except PoolTimeout as e: except PoolTimeout as e:
await pool.close() await pool.close()
logger = ServiceProviderABC.get_global_service(DBLogger) logger = get_provider().get_service(DBLogger)
logger.fatal(f"Failed to connect to the database", e) logger.fatal(f"Failed to connect to the database", e)
self._pool = pool self._pool = pool

View File

@@ -1,11 +1,11 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
class SeederService: class SeederService:
def __init__(self, provider: ServiceProviderABC): def __init__(self, provider: ServiceProvider):
self._provider = provider self._provider = provider
self._logger = provider.get_service(DBLogger) self._logger = provider.get_service(DBLogger)

View File

@@ -1,9 +1,7 @@
from .context import get_current_provider, use_provider from .context import get_provider, use_provider
from .inject import inject from .inject import inject
from .scope import Scope
from .scope_abc import ScopeABC
from .service_collection import ServiceCollection from .service_collection import ServiceCollection
from .service_descriptor import ServiceDescriptor from .service_descriptor import ServiceDescriptor
from .service_lifetime_enum import ServiceLifetimeEnum from .service_lifetime import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_provider import ServiceProvider from .service_provider import ServiceProvider
from .service_provider_abc import ServiceProviderABC

View File

@@ -17,5 +17,5 @@ def use_provider(provider):
_current_provider.reset(token) _current_provider.reset(token)
def get_current_provider(): def get_provider():
return _current_provider.get() return _current_provider.get()

View File

@@ -2,7 +2,7 @@ import functools
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
from inspect import signature from inspect import signature
from cpl.dependency.context import get_current_provider from cpl.dependency.context import get_provider
def inject(f=None): def inject(f=None):
@@ -15,7 +15,7 @@ def inject(f=None):
async def async_inner(*args, **kwargs): async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_current_provider() provider: ServiceProvider | None = get_provider()
if provider is None: if provider is None:
raise ValueError( raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context." "No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
@@ -30,7 +30,7 @@ def inject(f=None):
def inner(*args, **kwargs): def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_current_provider() provider: ServiceProvider | None = get_provider()
if provider is None: if provider is None:
raise ValueError( raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context." "No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."

View File

@@ -1,22 +0,0 @@
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class Scope(ScopeABC):
def __init__(self, service_provider: ServiceProviderABC):
self._service_provider = service_provider
self._service_provider.set_scope(self)
ScopeABC.__init__(self)
def __enter__(self):
return self
def __exit__(self, *args):
self.dispose()
@property
def service_provider(self) -> ServiceProviderABC:
return self._service_provider
def dispose(self):
self._service_provider = None

View File

@@ -1,20 +0,0 @@
from abc import ABC, abstractmethod
class ScopeABC(ABC):
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
def __init__(self): ...
@property
@abstractmethod
def service_provider(self):
r"""Returns to service provider of scope
Returns:
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
"""
@abstractmethod
def dispose(self):
r"""Sets service_provider to None"""

View File

@@ -1,18 +0,0 @@
from cpl.dependency.scope import Scope
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ScopeBuilder:
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
def __init__(self, service_provider: ServiceProviderABC) -> None:
self._service_provider = service_provider
def build(self) -> ScopeABC:
r"""Returns scope
Returns:
Object of type :class:`cpl.dependency.scope.Scope`
"""
return Scope(self._service_provider)

View File

@@ -4,9 +4,8 @@ from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache from cpl.core.utils.cache import Cache
from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceCollection: class ServiceCollection:
@@ -62,9 +61,8 @@ class ServiceCollection:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service) self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self return self
def build(self) -> ServiceProviderABC: def build(self) -> ServiceProvider:
sp = ServiceProvider(self._service_descriptors) sp = ServiceProvider(self._service_descriptors)
ServiceProviderABC.set_global_provider(sp)
return sp return sp
def add_module(self, module: str | object) -> Self: def add_module(self, module: str | object) -> Self:

View File

@@ -1,6 +1,6 @@
from typing import Union, Optional from typing import Union, Optional
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum from cpl.dependency.service_lifetime import ServiceLifetimeEnum
class ServiceDescriptor: class ServiceDescriptor:

View File

@@ -0,0 +1,7 @@
from enum import Enum, auto
class ServiceLifetimeEnum(Enum):
singleton = auto()
scoped = auto()
transient = auto()

View File

@@ -1,7 +0,0 @@
from enum import Enum
class ServiceLifetimeEnum(Enum):
singleton = 0
scoped = 1
transient = 2

View File

@@ -1,4 +1,3 @@
import copy
import typing import typing
from inspect import signature, Parameter, Signature from inspect import signature, Parameter, Signature
from typing import Optional, Type from typing import Optional, Type
@@ -7,14 +6,11 @@ from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.typing import T, R, Source from cpl.core.typing import T, R, Source
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.scope_builder import ScopeBuilder
from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceProvider(ServiceProviderABC): class ServiceProvider:
r"""Provider for the services r"""Provider for the services
Parameter Parameter
@@ -31,10 +27,7 @@ class ServiceProvider(ServiceProviderABC):
self, self,
service_descriptors: list[ServiceDescriptor], service_descriptors: list[ServiceDescriptor],
): ):
ServiceProviderABC.__init__(self)
self._service_descriptors: list[ServiceDescriptor] = service_descriptors self._service_descriptors: list[ServiceDescriptor] = service_descriptors
self._scope: Optional[ScopeABC] = None
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type origin_type = typing.get_origin(service_type) or service_type
@@ -103,7 +96,7 @@ class ServiceProvider(ServiceProviderABC):
elif parameter.annotation == Source: elif parameter.annotation == Source:
params.append(origin_service_type.__name__) params.append(origin_service_type.__name__)
elif issubclass(parameter.annotation, ServiceProviderABC): elif issubclass(parameter.annotation, ServiceProvider):
params.append(self) params.append(self)
elif issubclass(parameter.annotation, Environment): elif issubclass(parameter.annotation, Environment):
@@ -139,21 +132,6 @@ class ServiceProvider(ServiceProviderABC):
return service_type(*params, *args, **kwargs) return service_type(*params, *args, **kwargs)
def set_scope(self, scope: ScopeABC):
self._scope = scope
def create_scope(self) -> ScopeABC:
descriptors = []
for descriptor in self._service_descriptors:
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptors.append(descriptor)
else:
descriptors.append(copy.deepcopy(descriptor))
sb = ScopeBuilder(ServiceProvider(descriptors))
return sb.build()
def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]: def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
result = self._find_service(service_type) result = self._find_service(service_type)
@@ -166,8 +144,6 @@ class ServiceProvider(ServiceProviderABC):
implementation = self._build_service(service_type, *args, **kwargs) implementation = self._build_service(service_type, *args, **kwargs)
if ( if (
result.lifetime == ServiceLifetimeEnum.singleton result.lifetime == ServiceLifetimeEnum.singleton
or result.lifetime == ServiceLifetimeEnum.scoped
and self._scope is not None
): ):
result.implementation = implementation result.implementation = implementation

View File

@@ -1,127 +0,0 @@
from abc import abstractmethod, ABC
from inspect import Signature
from typing import Optional, Type
from cpl.core.typing import T
from cpl.dependency.scope_abc import ScopeABC
class ServiceProviderABC(ABC):
r"""ABC for the class :class:`cpl.dependency.service_provider.ServiceProvider`"""
_provider: Optional["ServiceProviderABC"] = None
@abstractmethod
def __init__(self): ...
@classmethod
def set_global_provider(cls, provider: "ServiceProviderABC"):
cls._provider = provider
@classmethod
def get_global_provider(cls) -> Optional["ServiceProviderABC"]:
return cls._provider
@classmethod
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
if cls._provider is None:
return None
return cls._provider.get_service(instance_type, *args, **kwargs)
@classmethod
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
if cls._provider is None:
return []
return cls._provider.get_services(instance_type, *args, **kwargs)
@abstractmethod
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
@abstractmethod
def _build_service(self, service_type: type, *args, **kwargs) -> object:
r"""Creates instance of given type
Parameter
---------
instance_type: :class:`type`
The type of the searched instance
Returns
-------
Object of the given type
"""
@abstractmethod
def set_scope(self, scope: ScopeABC):
r"""Sets the scope of service provider
Parameter
---------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
Service scope
"""
@abstractmethod
def create_scope(self) -> ScopeABC:
r"""Creates a service scope
Returns
-------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
"""
@abstractmethod
def get_service(self, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
r"""Returns instance of given type
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`cpl.core.type.T`]
"""
@abstractmethod
def get_service_type(self, instance_type: Type[T]) -> Optional[Type[T]]:
r"""Returns the registered service type for loggers
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`type`]
"""
@abstractmethod
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[Optional[:class:`cpl.core.type.T`]
"""
@abstractmethod
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
r"""Returns all registered service types
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[:class:`type`]
"""

View File

@@ -2,7 +2,7 @@ import unittest
from cpl.application import ApplicationABC from cpl.application import ApplicationABC
from cpl.core.configuration import ConfigurationABC from cpl.core.configuration import ConfigurationABC
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProvider
from unittests_cli.cli_test_suite import CLITestSuite from unittests_cli.cli_test_suite import CLITestSuite
from unittests_core.core_test_suite import CoreTestSuite from unittests_core.core_test_suite import CoreTestSuite
from unittests_query.query_test_suite import QueryTestSuite from unittests_query.query_test_suite import QueryTestSuite
@@ -10,7 +10,7 @@ from unittests_translation.translation_test_suite import TranslationTestSuite
class Application(ApplicationABC): class Application(ApplicationABC):
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC): def __init__(self, config: ConfigurationABC, services: ServiceProvider):
ApplicationABC.__init__(self, config, services) ApplicationABC.__init__(self, config, services)
def configure(self): ... def configure(self): ...

View File

@@ -2,7 +2,7 @@ import unittest
from unittest.mock import Mock from unittest.mock import Mock
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.dependency import ServiceCollection, ServiceLifetimeEnum, ServiceProviderABC from cpl.dependency import ServiceCollection, ServiceLifetimeEnum, ServiceProvider
class ServiceCollectionTestCase(unittest.TestCase): class ServiceCollectionTestCase(unittest.TestCase):
@@ -51,6 +51,6 @@ class ServiceCollectionTestCase(unittest.TestCase):
service = self._sc._service_descriptors[0] service = self._sc._service_descriptors[0]
self.assertIsNone(service.implementation) self.assertIsNone(service.implementation)
sp = self._sc.build() sp = self._sc.build()
self.assertTrue(isinstance(sp, ServiceProviderABC)) self.assertTrue(isinstance(sp, ServiceProvider))
self.assertTrue(isinstance(sp.get_service(Mock), Mock)) self.assertTrue(isinstance(sp.get_service(Mock), Mock))
self.assertIsNotNone(service.implementation) self.assertIsNotNone(service.implementation)

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.dependency import ServiceCollection, ServiceProviderABC from cpl.dependency import ServiceCollection, ServiceProvider
class ServiceCount: class ServiceCount:
@@ -10,21 +10,21 @@ class ServiceCount:
class TestService: class TestService:
def __init__(self, sp: ServiceProviderABC, count: ServiceCount): def __init__(self, sp: ServiceProvider, count: ServiceCount):
count.count += 1 count.count += 1
self.sp = sp self.sp = sp
self.id = count.count self.id = count.count
class DifferentService: class DifferentService:
def __init__(self, sp: ServiceProviderABC, count: ServiceCount): def __init__(self, sp: ServiceProvider, count: ServiceCount):
count.count += 1 count.count += 1
self.sp = sp self.sp = sp
self.id = count.count self.id = count.count
class MoreDifferentService: class MoreDifferentService:
def __init__(self, sp: ServiceProviderABC, count: ServiceCount): def __init__(self, sp: ServiceProvider, count: ServiceCount):
count.count += 1 count.count += 1
self.sp = sp self.sp = sp
self.id = count.count self.id = count.count
@@ -72,7 +72,7 @@ class ServiceProviderTestCase(unittest.TestCase):
singleton = self._services.get_service(TestService) singleton = self._services.get_service(TestService)
transient = self._services.get_service(DifferentService) transient = self._services.get_service(DifferentService)
with self._services.create_scope() as scope: with self._services.create_scope() as scope:
sp: ServiceProviderABC = scope.service_provider sp: ServiceProvider = scope.service_provider
self.assertNotEqual(sp, self._services) self.assertNotEqual(sp, self._services)
y = sp.get_service(DifferentService) y = sp.get_service(DifferentService)
self.assertIsNotNone(y) self.assertIsNotNone(y)