Fixed loading by base_type

This commit is contained in:
Sven Heidemann 2022-05-24 18:13:39 +02:00
parent 353c1d30ec
commit 35ecf158a2
4 changed files with 14 additions and 6 deletions

View File

@ -1,4 +1,4 @@
from typing import Union, Type, Callable, Optional from typing import Union, Type, Callable, Optional, overload
import lifetime as lifetime import lifetime as lifetime
@ -24,7 +24,7 @@ class ServiceCollection(ServiceCollectionABC):
self._database_context: Optional[DatabaseContextABC] = None self._database_context: Optional[DatabaseContextABC] = None
self._service_descriptors: list[ServiceDescriptor] = [] self._service_descriptors: list[ServiceDescriptor] = []
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum): def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
found = False found = False
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if isinstance(service, descriptor.service_type): if isinstance(service, descriptor.service_type):
@ -37,11 +37,11 @@ class ServiceCollection(ServiceCollectionABC):
raise Exception(f'Service of type {service_type} already exists') raise Exception(f'Service of type {service_type} already exists')
self._service_descriptors.append(ServiceDescriptor(service, lifetime)) self._service_descriptors.append(ServiceDescriptor(service, lifetime, base_type))
def _add_descriptor_by_lifetime(self, service_type: Type, lifetime: ServiceLifetimeEnum, service: Callable = None): def _add_descriptor_by_lifetime(self, service_type: Type, lifetime: ServiceLifetimeEnum, service: Callable = None):
if service is not None: if service is not None:
self._add_descriptor(service, lifetime) self._add_descriptor(service, lifetime, service_type)
else: else:
self._add_descriptor(service_type, lifetime) self._add_descriptor(service_type, lifetime)

View File

@ -1,5 +1,6 @@
from typing import Union, Optional from typing import Union, Optional
from cpl_core.console import Console
from cpl_core.dependency_injection.service_lifetime_enum import ServiceLifetimeEnum from cpl_core.dependency_injection.service_lifetime_enum import ServiceLifetimeEnum
@ -14,7 +15,7 @@ class ServiceDescriptor:
Lifetime of the service Lifetime of the service
""" """
def __init__(self, implementation: Union[type, Optional[object]], lifetime: ServiceLifetimeEnum): def __init__(self, implementation: Union[type, Optional[object]], lifetime: ServiceLifetimeEnum, base_type=None):
self._service_type = implementation self._service_type = implementation
self._implementation = implementation self._implementation = implementation
@ -25,10 +26,16 @@ class ServiceDescriptor:
else: else:
self._implementation = None self._implementation = None
self._base_type = base_type if base_type is not None else self._service_type
@property @property
def service_type(self) -> type: def service_type(self) -> type:
return self._service_type return self._service_type
@property
def base_type(self) -> type:
return self._base_type
@property @property
def implementation(self) -> Union[type, Optional[object]]: def implementation(self) -> Union[type, Optional[object]]:
return self._implementation return self._implementation

View File

@ -37,7 +37,7 @@ class ServiceProvider(ServiceProviderABC):
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type): if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
return descriptor return descriptor
return None return None

View File

@ -49,4 +49,5 @@ class Application(ApplicationABC):
test2: TestService = self._services.get_service(TestService) test2: TestService = self._services.get_service(TestService)
ip_pipe2: IPAddressPipe = self._services.get_service(IPAddressPipe) ip_pipe2: IPAddressPipe = self._services.get_service(IPAddressPipe)
Console.write_line(f'DI working: {test == test2 and ip_pipe != ip_pipe2}') Console.write_line(f'DI working: {test == test2 and ip_pipe != ip_pipe2}')
Console.write_line(self._services.get_service(LoggerABC))
# self.test_send_mail() # self.test_send_mail()