diff --git a/unittests/unittests_core/core_test_suite.py b/unittests/unittests_core/core_test_suite.py index b1e0c96b..f3321712 100644 --- a/unittests/unittests_core/core_test_suite.py +++ b/unittests/unittests_core/core_test_suite.py @@ -3,6 +3,8 @@ import unittest from unittests_core.configuration.console_arguments_test_case import ConsoleArgumentsTestCase from unittests_core.configuration.configuration_test_case import ConfigurationTestCase from unittests_core.configuration.environment_test_case import EnvironmentTestCase +from unittests_core.di.service_collection_test_case import ServiceCollectionTestCase +from unittests_core.di.service_provider_test_case import ServiceProviderTestCase from unittests_core.pipes.bool_pipe_test_case import BoolPipeTestCase from unittests_core.pipes.ip_address_pipe_test_case import IPAddressTestCase from unittests_core.pipes.version_pipe_test_case import VersionPipeTestCase @@ -21,6 +23,9 @@ class CoreTestSuite(unittest.TestSuite): ConfigurationTestCase, ConsoleArgumentsTestCase, EnvironmentTestCase, + # di + ServiceCollectionTestCase, + ServiceProviderTestCase, # pipes BoolPipeTestCase, IPAddressTestCase, diff --git a/unittests/unittests_core/di/__init__.py b/unittests/unittests_core/di/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unittests/unittests_core/di/service_collection_test_case.py b/unittests/unittests_core/di/service_collection_test_case.py new file mode 100644 index 00000000..ba41797f --- /dev/null +++ b/unittests/unittests_core/di/service_collection_test_case.py @@ -0,0 +1,56 @@ +import unittest +from unittest.mock import Mock + +from cpl_core.configuration import Configuration +from cpl_core.dependency_injection import ServiceCollection, ServiceLifetimeEnum, ServiceProviderABC + + +class ServiceCollectionTestCase(unittest.TestCase): + def setUp(self): + self._sc = ServiceCollection(Configuration()) + + def test_add_singleton_type(self): + self._sc.add_singleton(Mock) + + service = self._sc._service_descriptors[0] + self.assertEqual(ServiceLifetimeEnum.singleton, service.lifetime) + self.assertEqual(Mock, service.service_type) + self.assertEqual(Mock, service.base_type) + self.assertIsNone(service.implementation) + + def test_add_singleton_instance(self): + mock = Mock() + self._sc.add_singleton(mock) + + service = self._sc._service_descriptors[0] + self.assertEqual(ServiceLifetimeEnum.singleton, service.lifetime) + self.assertEqual(type(mock), service.service_type) + self.assertEqual(type(mock), service.base_type) + self.assertIsNotNone(service.implementation) + + def test_add_transient_type(self): + self._sc.add_transient(Mock) + + service = self._sc._service_descriptors[0] + self.assertEqual(ServiceLifetimeEnum.transient, service.lifetime) + self.assertEqual(Mock, service.service_type) + self.assertEqual(Mock, service.base_type) + self.assertIsNone(service.implementation) + + def test_add_scoped_type(self): + self._sc.add_scoped(Mock) + + service = self._sc._service_descriptors[0] + self.assertEqual(ServiceLifetimeEnum.scoped, service.lifetime) + self.assertEqual(Mock, service.service_type) + self.assertEqual(Mock, service.base_type) + self.assertIsNone(service.implementation) + + def test_build_service_provider(self): + self._sc.add_singleton(Mock) + service = self._sc._service_descriptors[0] + self.assertIsNone(service.implementation) + sp = self._sc.build_service_provider() + self.assertTrue(isinstance(sp, ServiceProviderABC)) + self.assertTrue(isinstance(sp.get_service(Mock), Mock)) + self.assertIsNotNone(service.implementation) diff --git a/unittests/unittests_core/di/service_provider_test_case.py b/unittests/unittests_core/di/service_provider_test_case.py new file mode 100644 index 00000000..637a8aa8 --- /dev/null +++ b/unittests/unittests_core/di/service_provider_test_case.py @@ -0,0 +1,85 @@ +import unittest + +from cpl_core.configuration import Configuration +from cpl_core.dependency_injection import ServiceCollection, ServiceProviderABC + + +class ServiceCount: + def __init__(self): + self.count = 0 + + +class TestService: + def __init__(self, count: ServiceCount): + count.count += 1 + self.id = count.count + + +class DifferentService: + def __init__(self, count: ServiceCount): + count.count += 1 + self.id = count.count + + +class MoreDifferentService: + def __init__(self, count: ServiceCount): + count.count += 1 + self.id = count.count + + +class ServiceProviderTestCase(unittest.TestCase): + def setUp(self): + self._services = ( + ServiceCollection(Configuration()) + .add_singleton(ServiceCount) + .add_singleton(TestService) + .add_singleton(TestService) + .add_transient(DifferentService) + .add_scoped(MoreDifferentService) + .build_service_provider() + ) + + count = self._services.get_service(ServiceCount) + + def test_get_singleton(self): + x = self._services.get_service(TestService) + self.assertIsNotNone(x) + self.assertEqual(1, x.id) + self.assertEqual(x, self._services.get_service(TestService)) + self.assertEqual(x, self._services.get_service(TestService)) + self.assertEqual(x, self._services.get_service(TestService)) + + def test_get_singletons(self): + x = self._services.get_services(list[TestService]) + self.assertEqual(2, len(x)) + self.assertEqual(1, x[0].id) + self.assertEqual(2, x[1].id) + self.assertNotEqual(x[0], x[1]) + + def test_get_transient(self): + x = self._services.get_service(DifferentService) + self.assertIsNotNone(x) + self.assertEqual(1, x.id) + self.assertNotEqual(x, self._services.get_service(DifferentService)) + self.assertNotEqual(x, self._services.get_service(DifferentService)) + self.assertNotEqual(x, self._services.get_service(DifferentService)) + + def test_scoped(self): + scoped_id = 0 + singleton = self._services.get_service(TestService) + with self._services.create_scope() as scope: + sp: ServiceProviderABC = scope.service_provider + y = sp.get_service(DifferentService) + self.assertIsNotNone(y) + self.assertEqual(2, y.id) + x = sp.get_service(MoreDifferentService) + self.assertIsNotNone(x) + self.assertEqual(3, x.id) + scoped_id = 3 + self.assertEqual(x.id, sp.get_service(MoreDifferentService).id) + self.assertEqual(x.id, sp.get_service(MoreDifferentService).id) + self.assertNotEqual(x, self._services.get_service(MoreDifferentService)) + self.assertEqual(singleton, self._services.get_service(TestService)) + + self.assertIsNone(scope.service_provider) + self.assertNotEqual(scoped_id, self._services.get_service(MoreDifferentService).id)