From 01ef96518062bc6bbe729385697fd0f028651c04 Mon Sep 17 00:00:00 2001 From: Sven Heidemann Date: Thu, 4 Mar 2021 17:55:15 +0100 Subject: [PATCH] Improved service providing --- .../dependency_injection/service_provider.py | 21 +++++++++++++------ .../service_provider_base.py | 6 +++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/cpl/dependency_injection/service_provider.py b/src/cpl/dependency_injection/service_provider.py index bc030417..a64d014d 100644 --- a/src/cpl/dependency_injection/service_provider.py +++ b/src/cpl/dependency_injection/service_provider.py @@ -52,18 +52,27 @@ class ServiceProvider(ServiceProviderABC): def get_db_context(self) -> Callable[DatabaseContextABC]: return self._database_context - def add_transient(self, service_type: Type[ServiceABC], service: Callable[ServiceABC]): - self._transient_services[service_type] = service + def add_transient(self, service_type: Type[ServiceABC], service: Callable[ServiceABC] = None): + if service is None: + self._transient_services[service_type] = service_type + else: + self._transient_services[service_type] = service - def add_scoped(self, service_type: Type[ServiceABC], service: Callable[ServiceABC]): - self._scoped_services[service_type] = service + def add_scoped(self, service_type: Type[ServiceABC], service: Callable[ServiceABC] = None): + if service is None: + self._scoped_services[service_type] = service_type + else: + self._scoped_services[service_type] = service - def add_singleton(self, service_type: Type[ServiceABC], service: Callable[ServiceABC]): + def add_singleton(self, service_type: Type[ServiceABC], service: Callable[ServiceABC] = None): for known_service in self._singleton_services: if type(known_service) == service_type: raise Exception(f'Service with type {service_type} already exists') - self._singleton_services[service_type] = self._create_instance(service) + if service is None: + self._singleton_services[service_type] = self._create_instance(service_type) + else: + self._singleton_services[service_type] = self._create_instance(service) def get_service(self, instance_type: Type) -> Callable[ServiceABC]: for service in self._transient_services: diff --git a/src/cpl/dependency_injection/service_provider_base.py b/src/cpl/dependency_injection/service_provider_base.py index f14e6cd9..059dd6f3 100644 --- a/src/cpl/dependency_injection/service_provider_base.py +++ b/src/cpl/dependency_injection/service_provider_base.py @@ -18,13 +18,13 @@ class ServiceProviderABC(ABC): def get_db_context(self) -> Callable[DatabaseContextABC]: pass @abstractmethod - def add_transient(self, service_type: Type, service: Callable): pass + def add_transient(self, service_type: Type, service: Callable = None): pass @abstractmethod - def add_scoped(self, service_type: Type, service: Callable): pass + def add_scoped(self, service_type: Type, service: Callable = None): pass @abstractmethod - def add_singleton(self, service_type: Type, service: Callable): pass + def add_singleton(self, service_type: Type, service: Callable = None): pass @abstractmethod def get_service(self, instance_type: Type) -> Callable[ServiceABC]: pass