diff --git a/src/cpl_core/dependency_injection/service_collection.py b/src/cpl_core/dependency_injection/service_collection.py index a608c710..51817b2d 100644 --- a/src/cpl_core/dependency_injection/service_collection.py +++ b/src/cpl_core/dependency_injection/service_collection.py @@ -1,4 +1,4 @@ -from typing import Union, Type, Callable, Optional +from typing import Union, Type, Callable, Optional, overload import lifetime as lifetime @@ -24,7 +24,7 @@ class ServiceCollection(ServiceCollectionABC): self._database_context: Optional[DatabaseContextABC] = None self._service_descriptors: list[ServiceDescriptor] = [] - def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum): + def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None): found = False for descriptor in self._service_descriptors: if isinstance(service, descriptor.service_type): @@ -37,11 +37,11 @@ class ServiceCollection(ServiceCollectionABC): raise Exception(f'Service of type {service_type} already exists') - self._service_descriptors.append(ServiceDescriptor(service, lifetime)) + self._service_descriptors.append(ServiceDescriptor(service, lifetime, base_type)) def _add_descriptor_by_lifetime(self, service_type: Type, lifetime: ServiceLifetimeEnum, service: Callable = None): if service is not None: - self._add_descriptor(service, lifetime) + self._add_descriptor(service, lifetime, service_type) else: self._add_descriptor(service_type, lifetime) diff --git a/src/cpl_core/dependency_injection/service_descriptor.py b/src/cpl_core/dependency_injection/service_descriptor.py index 2188ffd0..1c559b1f 100644 --- a/src/cpl_core/dependency_injection/service_descriptor.py +++ b/src/cpl_core/dependency_injection/service_descriptor.py @@ -1,5 +1,6 @@ from typing import Union, Optional +from cpl_core.console import Console from cpl_core.dependency_injection.service_lifetime_enum import ServiceLifetimeEnum @@ -14,7 +15,7 @@ class ServiceDescriptor: Lifetime of the service """ - def __init__(self, implementation: Union[type, Optional[object]], lifetime: ServiceLifetimeEnum): + def __init__(self, implementation: Union[type, Optional[object]], lifetime: ServiceLifetimeEnum, base_type=None): self._service_type = implementation self._implementation = implementation @@ -25,10 +26,16 @@ class ServiceDescriptor: else: self._implementation = None + self._base_type = base_type if base_type is not None else self._service_type + @property def service_type(self) -> type: return self._service_type + @property + def base_type(self) -> type: + return self._base_type + @property def implementation(self) -> Union[type, Optional[object]]: return self._implementation diff --git a/src/cpl_core/dependency_injection/service_provider.py b/src/cpl_core/dependency_injection/service_provider.py index 45687060..acf3cb95 100644 --- a/src/cpl_core/dependency_injection/service_provider.py +++ b/src/cpl_core/dependency_injection/service_provider.py @@ -37,7 +37,7 @@ class ServiceProvider(ServiceProviderABC): def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: for descriptor in self._service_descriptors: - if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type): + if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type): return descriptor return None diff --git a/src/tests/custom/general/src/general/application.py b/src/tests/custom/general/src/general/application.py index 8f282560..c8889159 100644 --- a/src/tests/custom/general/src/general/application.py +++ b/src/tests/custom/general/src/general/application.py @@ -49,4 +49,5 @@ class Application(ApplicationABC): test2: TestService = self._services.get_service(TestService) ip_pipe2: IPAddressPipe = self._services.get_service(IPAddressPipe) Console.write_line(f'DI working: {test == test2 and ip_pipe != ip_pipe2}') + Console.write_line(self._services.get_service(LoggerABC)) # self.test_send_mail()