Fixed loading by base_type
This commit is contained in:
		@@ -1,4 +1,4 @@
 | 
			
		||||
from typing import Union, Type, Callable, Optional
 | 
			
		||||
from typing import Union, Type, Callable, Optional, overload
 | 
			
		||||
 | 
			
		||||
import lifetime as lifetime
 | 
			
		||||
 | 
			
		||||
@@ -24,7 +24,7 @@ class ServiceCollection(ServiceCollectionABC):
 | 
			
		||||
        self._database_context: Optional[DatabaseContextABC] = None
 | 
			
		||||
        self._service_descriptors: list[ServiceDescriptor] = []
 | 
			
		||||
 | 
			
		||||
    def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum):
 | 
			
		||||
    def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
 | 
			
		||||
        found = False
 | 
			
		||||
        for descriptor in self._service_descriptors:
 | 
			
		||||
            if isinstance(service, descriptor.service_type):
 | 
			
		||||
@@ -37,11 +37,11 @@ class ServiceCollection(ServiceCollectionABC):
 | 
			
		||||
 | 
			
		||||
            raise Exception(f'Service of type {service_type} already exists')
 | 
			
		||||
 | 
			
		||||
        self._service_descriptors.append(ServiceDescriptor(service, lifetime))
 | 
			
		||||
        self._service_descriptors.append(ServiceDescriptor(service, lifetime, base_type))
 | 
			
		||||
 | 
			
		||||
    def _add_descriptor_by_lifetime(self, service_type: Type, lifetime: ServiceLifetimeEnum, service: Callable = None):
 | 
			
		||||
        if service is not None:
 | 
			
		||||
            self._add_descriptor(service, lifetime)
 | 
			
		||||
            self._add_descriptor(service, lifetime, service_type)
 | 
			
		||||
        else:
 | 
			
		||||
            self._add_descriptor(service_type, lifetime)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
from typing import Union, Optional
 | 
			
		||||
 | 
			
		||||
from cpl_core.console import Console
 | 
			
		||||
from cpl_core.dependency_injection.service_lifetime_enum import ServiceLifetimeEnum
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -14,7 +15,7 @@ class ServiceDescriptor:
 | 
			
		||||
            Lifetime of the service
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, implementation: Union[type, Optional[object]], lifetime: ServiceLifetimeEnum):
 | 
			
		||||
    def __init__(self, implementation: Union[type, Optional[object]], lifetime: ServiceLifetimeEnum, base_type=None):
 | 
			
		||||
 | 
			
		||||
        self._service_type = implementation
 | 
			
		||||
        self._implementation = implementation
 | 
			
		||||
@@ -25,10 +26,16 @@ class ServiceDescriptor:
 | 
			
		||||
        else:
 | 
			
		||||
            self._implementation = None
 | 
			
		||||
 | 
			
		||||
        self._base_type = base_type if base_type is not None else self._service_type
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def service_type(self) -> type:
 | 
			
		||||
        return self._service_type
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def base_type(self) -> type:
 | 
			
		||||
        return self._base_type
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def implementation(self) -> Union[type, Optional[object]]:
 | 
			
		||||
        return self._implementation
 | 
			
		||||
 
 | 
			
		||||
@@ -37,7 +37,7 @@ class ServiceProvider(ServiceProviderABC):
 | 
			
		||||
 | 
			
		||||
    def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
 | 
			
		||||
        for descriptor in self._service_descriptors:
 | 
			
		||||
            if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
 | 
			
		||||
            if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
 | 
			
		||||
                return descriptor
 | 
			
		||||
 | 
			
		||||
        return None
 | 
			
		||||
 
 | 
			
		||||
@@ -49,4 +49,5 @@ class Application(ApplicationABC):
 | 
			
		||||
        test2: TestService = self._services.get_service(TestService)
 | 
			
		||||
        ip_pipe2: IPAddressPipe = self._services.get_service(IPAddressPipe)
 | 
			
		||||
        Console.write_line(f'DI working: {test == test2 and ip_pipe != ip_pipe2}')
 | 
			
		||||
        Console.write_line(self._services.get_service(LoggerABC))
 | 
			
		||||
        # self.test_send_mail()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user