diff --git a/src/cpl/dependency_injection/service_collection.py b/src/cpl/dependency_injection/service_collection.py index 4be8ed7b..be97fb41 100644 --- a/src/cpl/dependency_injection/service_collection.py +++ b/src/cpl/dependency_injection/service_collection.py @@ -23,10 +23,7 @@ class ServiceCollection(ServiceCollectionABC): def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum): found = False for descriptor in self._service_descriptors: - if not isinstance(service, type): - service = type(service) - - if descriptor.service_type == service: + if isinstance(service, descriptor.service_type): found = True if found: @@ -39,23 +36,24 @@ class ServiceCollection(ServiceCollectionABC): self._service_descriptors.append(ServiceDescriptor(service, lifetime)) def add_db_context(self, db_context_type: Type[DatabaseContextABC], db_settings: DatabaseSettings): - db_context = db_context_type(db_settings) - db_context.connect(CredentialManager.build_string(db_settings.connection_string, db_settings.credentials)) + self._database_context = db_context_type(db_settings) + self._database_context.connect(CredentialManager.build_string(db_settings.connection_string, db_settings.credentials)) def add_singleton(self, service_type: Union[type, object], service: Union[type, object] = None): + impl = None if service is not None: if isinstance(service, type): - service = self.build_service_provider().build_service(service) + impl = self.build_service_provider().build_service(service) - self._add_descriptor(service, ServiceLifetimeEnum.singleton) + self._add_descriptor(impl, ServiceLifetimeEnum.singleton) else: if isinstance(service_type, type): - service_type = self.build_service_provider().build_service(service_type) + impl = self.build_service_provider().build_service(service_type) - self._add_descriptor(service_type, ServiceLifetimeEnum.singleton) + self._add_descriptor(impl, ServiceLifetimeEnum.singleton) def add_scoped(self, service_type: Type, service: Callable = None): - pass + raise Exception('Not implemented') def add_transient(self, service_type: Union[type], service: Union[type] = None): if service is not None: @@ -64,4 +62,4 @@ class ServiceCollection(ServiceCollectionABC): self._add_descriptor(service_type, ServiceLifetimeEnum.transient) def build_service_provider(self) -> ServiceProviderABC: - return ServiceProvider(self._service_descriptors, self._configuration) + return ServiceProvider(self._service_descriptors, self._configuration, self._database_context) diff --git a/src/cpl/dependency_injection/service_provider.py b/src/cpl/dependency_injection/service_provider.py index e0de3984..c4c8b2df 100644 --- a/src/cpl/dependency_injection/service_provider.py +++ b/src/cpl/dependency_injection/service_provider.py @@ -4,6 +4,7 @@ from typing import Optional from cpl.configuration.configuration_abc import ConfigurationABC from cpl.configuration.configuration_model_abc import ConfigurationModelABC +from cpl.database.context.database_context_abc import DatabaseContextABC from cpl.dependency_injection.service_provider_abc import ServiceProviderABC from cpl.dependency_injection.service_descriptor import ServiceDescriptor from cpl.dependency_injection.service_lifetime_enum import ServiceLifetimeEnum @@ -12,11 +13,12 @@ from cpl.environment.application_environment_abc import ApplicationEnvironmentAB class ServiceProvider(ServiceProviderABC): - def __init__(self, service_descriptors: list[ServiceDescriptor], config: ConfigurationABC): + def __init__(self, service_descriptors: list[ServiceDescriptor], config: ConfigurationABC, db_context: Optional[DatabaseContextABC]): ServiceProviderABC.__init__(self) self._service_descriptors: list[ServiceDescriptor] = service_descriptors self._configuration: ConfigurationABC = config + self._database_context = db_context def _find_service(self, service_type: type) -> [ServiceDescriptor]: for descriptor in self._service_descriptors: @@ -58,8 +60,8 @@ class ServiceProvider(ServiceProviderABC): elif issubclass(parameter.annotation, ApplicationEnvironmentABC): params.append(self._configuration.environment) - # elif issubclass(parameter.annotation, DatabaseContextABC): - # params.append(self._database_context) + elif issubclass(parameter.annotation, DatabaseContextABC): + params.append(self._database_context) elif issubclass(parameter.annotation, ConfigurationModelABC): params.append(self._configuration.get_configuration(parameter.annotation)) diff --git a/src/tests/custom/database/__init__.py b/src/tests/custom/database/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/tests/custom/database/src/__init__.py b/src/tests/custom/database/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/tests/custom/database/src/application.py b/src/tests/custom/database/src/application.py index 85b5505c..1f58eb76 100644 --- a/src/tests/custom/database/src/application.py +++ b/src/tests/custom/database/src/application.py @@ -5,6 +5,7 @@ from cpl.configuration import ConfigurationABC from cpl.console import Console from cpl.dependency_injection import ServiceProviderABC from cpl.logging import LoggerABC +from model.user_repo_abc import UserRepoABC class Application(ApplicationABC): @@ -22,3 +23,4 @@ class Application(ApplicationABC): self._logger.debug(__name__, f'Host: {self._configuration.environment.host_name}') self._logger.debug(__name__, f'Environment: {self._configuration.environment.environment_name}') self._logger.debug(__name__, f'Customer: {self._configuration.environment.customer}') + self._services.get_service(UserRepoABC).add_test_user() diff --git a/src/tests/custom/database/src/model/city_model.py b/src/tests/custom/database/src/model/city_model.py new file mode 100644 index 00000000..98523169 --- /dev/null +++ b/src/tests/custom/database/src/model/city_model.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer, String + +from cpl.database import DatabaseModel + + +class CityModel(DatabaseModel): + __tablename__ = 'Cities' + Id = Column(Integer, primary_key=True, nullable=False, autoincrement=True) + Name = Column(String(64), nullable=False) + ZIP = Column(String(5), nullable=False) + + def __init__(self, name: str, zip_code: str): + self.Name = name + self.ZIP = zip_code diff --git a/src/tests/custom/database/src/model/user_model.py b/src/tests/custom/database/src/model/user_model.py new file mode 100644 index 00000000..bce0ed69 --- /dev/null +++ b/src/tests/custom/database/src/model/user_model.py @@ -0,0 +1,18 @@ +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship + +from cpl.database import DatabaseModel +from .city_model import CityModel + + +class UserModel(DatabaseModel): + __tablename__ = 'Users' + Id = Column(Integer, primary_key=True, nullable=False, autoincrement=True) + Name = Column(String(64), nullable=False) + City_Id = Column(Integer, ForeignKey('Cities.Id'), nullable=False) + City = relationship("CityModel") + + def __init__(self, name: str, city: CityModel): + self.Name = name + self.City_Id = city.Id + self.City = city diff --git a/src/tests/custom/database/src/model/user_repo.py b/src/tests/custom/database/src/model/user_repo.py new file mode 100644 index 00000000..414a7638 --- /dev/null +++ b/src/tests/custom/database/src/model/user_repo.py @@ -0,0 +1,23 @@ +from cpl.database.context import DatabaseContextABC +from .city_model import CityModel +from .user_model import UserModel +from .user_repo_abc import UserRepoABC + + +class UserRepo(UserRepoABC): + + def __init__(self, db_context: DatabaseContextABC): + UserRepoABC.__init__(self) + + self._session = db_context.session + self._user_query = db_context.session.query(UserModel) + + def create(self): pass + + def add_test_user(self): + city = CityModel('Haren', '49733') + city2 = CityModel('Meppen', '49716') + self._session.add(city2) + user = UserModel('TestUser', city) + self._session.add(user) + self._session.commit() diff --git a/src/tests/custom/database/src/model/user_repo_abc.py b/src/tests/custom/database/src/model/user_repo_abc.py new file mode 100644 index 00000000..dc0ce31f --- /dev/null +++ b/src/tests/custom/database/src/model/user_repo_abc.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + + +class UserRepoABC(ABC): + + @abstractmethod + def __init__(self): pass diff --git a/src/tests/custom/database/src/startup.py b/src/tests/custom/database/src/startup.py index f4a37129..27417a3b 100644 --- a/src/tests/custom/database/src/startup.py +++ b/src/tests/custom/database/src/startup.py @@ -4,6 +4,8 @@ from cpl.database import DatabaseSettings from cpl.dependency_injection import ServiceProviderABC, ServiceCollectionABC from cpl.logging import LoggerABC, Logger from model.db_context import DBContext +from model.user_repo import UserRepo +from model.user_repo_abc import UserRepoABC class Startup(StartupABC): @@ -30,6 +32,8 @@ class Startup(StartupABC): db_settings: DatabaseSettings = self._configuration.get_configuration(DatabaseSettings) self._services.add_db_context(DBContext, db_settings) + self._services.add_singleton(UserRepoABC, UserRepo) + self._services.add_singleton(LoggerABC, Logger) return self._services.build_service_provider()