diff --git a/example/database/src/startup.py b/example/database/src/startup.py index d9eab6c8..31fd4c54 100644 --- a/example/database/src/startup.py +++ b/example/database/src/startup.py @@ -1,11 +1,14 @@ from cpl import auth from cpl.application.abc.startup_abc import StartupABC from cpl.auth import permission +from cpl.auth.auth_module import AuthModule +from cpl.auth.permission.permission_module import PermissionsModule from cpl.core.configuration import Configuration from cpl.core.environment import Environment from cpl.core.log import Logger, LoggerABC from cpl.database import mysql from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.database.mysql.mysql_module import MySQLModule from cpl.dependency import ServiceCollection from model.city_dao import CityDao from model.user_dao import UserDao @@ -21,9 +24,9 @@ class Startup(StartupABC): @staticmethod async def configure_services(services: ServiceCollection): - services.add_module(mysql) - services.add_module(auth) - services.add_module(permission) + services.add_module(MySQLModule) + services.add_module(AuthModule) + services.add_module(PermissionsModule) services.add_transient(DataAccessObjectABC, UserDao) services.add_transient(DataAccessObjectABC, CityDao) diff --git a/example/di/src/application.py b/example/di/src/application.py index b8a62dd6..11af3e1d 100644 --- a/example/di/src/application.py +++ b/example/di/src/application.py @@ -1,11 +1,10 @@ from cpl.application.abc import ApplicationABC from cpl.core.console.console import Console from cpl.dependency import ServiceProvider -from di.static_test import StaticTest -from di.test_abc import TestABC -from di.test_service import TestService -from di.di_tester_service import DITesterService -from di.tester import Tester +from test_abc import TestABC +from test_service import TestService +from di_tester_service import DITesterService +from tester import Tester class Application(ApplicationABC): @@ -39,7 +38,8 @@ class Application(ApplicationABC): Console.write_line("Global") self._part_of_scoped() - StaticTest.test() + #from static_test import StaticTest + #StaticTest.test() self._services.get_service(Tester) Console.write_line(self._services.get_services(TestABC)) diff --git a/example/di/src/di_tester_service.py b/example/di/src/di_tester_service.py index 9937f561..e250badb 100644 --- a/example/di/src/di_tester_service.py +++ b/example/di/src/di_tester_service.py @@ -1,5 +1,5 @@ from cpl.core.console.console import Console -from di.test_service import TestService +from test_service import TestService class DITesterService: diff --git a/example/di/src/main.py b/example/di/src/main.py index a5ba63d8..06ef261b 100644 --- a/example/di/src/main.py +++ b/example/di/src/main.py @@ -1,7 +1,7 @@ from cpl.application import ApplicationBuilder -from di.application import Application -from di.startup import Startup +from application import Application +from startup import Startup def main(): diff --git a/example/di/src/startup.py b/example/di/src/startup.py index 89e7e1a7..0b949d37 100644 --- a/example/di/src/startup.py +++ b/example/di/src/startup.py @@ -1,11 +1,11 @@ from cpl.application.abc import StartupABC from cpl.dependency import ServiceProvider, ServiceCollection -from di.di_tester_service import DITesterService -from di.test1_service import Test1Service -from di.test2_service import Test2Service -from di.test_abc import TestABC -from di.test_service import TestService -from di.tester import Tester +from di_tester_service import DITesterService +from test1_service import Test1Service +from test2_service import Test2Service +from test_abc import TestABC +from test_service import TestService +from tester import Tester class Startup(StartupABC): diff --git a/example/di/src/static_test.py b/example/di/src/static_test.py index e60d3f53..775b758f 100644 --- a/example/di/src/static_test.py +++ b/example/di/src/static_test.py @@ -1,6 +1,6 @@ from cpl.dependency import ServiceProvider, ServiceProvider from cpl.dependency.inject import inject -from di.test_service import TestService +from test_service import TestService class StaticTest: diff --git a/example/di/src/test1_service.py b/example/di/src/test1_service.py index 21852f49..c9a60dd7 100644 --- a/example/di/src/test1_service.py +++ b/example/di/src/test1_service.py @@ -1,7 +1,7 @@ import string from cpl.core.console.console import Console from cpl.core.utils.string import String -from di.test_abc import TestABC +from test_abc import TestABC class Test1Service(TestABC): diff --git a/example/di/src/test2_service.py b/example/di/src/test2_service.py index 06832778..428be96b 100644 --- a/example/di/src/test2_service.py +++ b/example/di/src/test2_service.py @@ -1,7 +1,7 @@ import string from cpl.core.console.console import Console from cpl.core.utils.string import String -from di.test_abc import TestABC +from test_abc import TestABC class Test2Service(TestABC): diff --git a/example/di/src/tester.py b/example/di/src/tester.py index a05914cb..94e61e35 100644 --- a/example/di/src/tester.py +++ b/example/di/src/tester.py @@ -1,8 +1,7 @@ from cpl.core.console.console import Console -from di.test_abc import TestABC +from test_abc import TestABC class Tester: - def __init__(self, t1: TestABC, t2: TestABC, t3: list[TestABC]): - Console.write_line("Tester:") - Console.write_line(t1, t2, t3) + def __init__(self, t1: TestABC, t2: TestABC, t3: TestABC, t: list[TestABC]): + Console.write_line("Tester:", t, t1, t2, t3) diff --git a/example/query/main.py b/example/query/main.py index 883b4aa2..780a26c4 100644 --- a/example/query/main.py +++ b/example/query/main.py @@ -48,9 +48,9 @@ def t_benchmark(data: list): def main(): - N = 10_000_000 + N = 1_000_000 data = list(range(N)) - #t_benchmark(data) + t_benchmark(data) Console.write_line() _default() diff --git a/src/cpl-application/cpl/application/application_builder.py b/src/cpl-application/cpl/application/application_builder.py index 5d7ee4bd..073c0ae2 100644 --- a/src/cpl-application/cpl/application/application_builder.py +++ b/src/cpl-application/cpl/application/application_builder.py @@ -83,6 +83,7 @@ class ApplicationBuilder(Generic[TApp]): for extension in self._app_extensions: Host.run(extension.run, self.service_provider) + use_root_provider(self._services.build()) app = self._app(self.service_provider) self.validate_app_required_modules(app) return app diff --git a/src/cpl-auth/cpl/auth/permission/permission_module.py b/src/cpl-auth/cpl/auth/permission/permission_module.py index 6096d2b9..f70aa7e0 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_module.py +++ b/src/cpl-auth/cpl/auth/permission/permission_module.py @@ -1,8 +1,8 @@ +from cpl.auth.auth_module import AuthModule from cpl.auth.permission.permission_seeder import PermissionSeeder from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.database.abc.data_seeder_abc import DataSeederABC -from cpl.database.model.server_type import ServerType, ServerTypes from cpl.dependency.module import Module, TModule from cpl.dependency.service_collection import ServiceCollection @@ -12,21 +12,7 @@ class PermissionsModule(Module): def dependencies() -> list[TModule]: from cpl.database.database_module import DatabaseModule - r = [DatabaseModule] - - match ServerType.server_type: - case ServerTypes.POSTGRES: - from cpl.database.postgres.postgres_module import PostgresModule - - r.append(PostgresModule) - case ServerTypes.MYSQL: - from cpl.database.mysql.mysql_module import MySQLModule - - r.append(MySQLModule) - case _: - raise Exception(f"Unsupported database type: {ServerType.server_type}") - - return r + return [DatabaseModule, AuthModule] @staticmethod def register(collection: ServiceCollection): diff --git a/src/cpl-dependency/cpl/dependency/context.py b/src/cpl-dependency/cpl/dependency/context.py index 1254b982..f4d8a331 100644 --- a/src/cpl-dependency/cpl/dependency/context.py +++ b/src/cpl-dependency/cpl/dependency/context.py @@ -2,7 +2,7 @@ import contextvars from contextlib import contextmanager -_current_provider = contextvars.ContextVar("current_provider") +_current_provider = contextvars.ContextVar("current_provider", default=None) def use_root_provider(provider: "ServiceProvider"): diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index f6ed2318..2a529cf4 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -23,6 +23,9 @@ class ServiceProvider: type_args = list(typing.get_args(service_type)) for descriptor in self._service_descriptors: + if typing.get_origin(service_type) is None and (descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type)): + return descriptor + descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type descriptor_type_args = list(typing.get_args(descriptor.base_type)) @@ -48,7 +51,7 @@ class ServiceProvider: if descriptor.implementation is not None: return descriptor.implementation - implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type) + implementation = self._build_service(descriptor, origin_service_type=origin_service_type) if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped): descriptor.implementation = implementation @@ -63,7 +66,7 @@ class ServiceProvider: continue implementation = self._build_service( - descriptor.service_type, *args, origin_service_type=service_type, **kwargs + descriptor, *args, origin_service_type=service_type, **kwargs ) if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped): descriptor.implementation = implementation @@ -78,7 +81,7 @@ class ServiceProvider: parameter = param[1] if parameter.name != "self" and parameter.annotation != Parameter.empty: if typing.get_origin(parameter.annotation) == list: - params.append(self._get_services(typing.get_args(parameter.annotation)[0], origin_service_type)) + params.append(self._get_services(typing.get_args(parameter.annotation)[0], service_type=origin_service_type)) elif parameter.annotation == Source: params.append(origin_service_type.__name__) @@ -101,18 +104,15 @@ class ServiceProvider: return params - def _build_service(self, service_type: type, *args, origin_service_type: type = None, **kwargs) -> object: + def _build_service(self, descriptor: ServiceDescriptor, *args, origin_service_type: type = None, **kwargs) -> object: + if descriptor.implementation is not None: + service_type = type(descriptor.implementation) + else: + service_type = descriptor.service_type + if origin_service_type is None: origin_service_type = service_type - for descriptor in self._service_descriptors: - if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type): - if descriptor.implementation is not None: - service_type = type(descriptor.implementation) - else: - service_type = descriptor.service_type - break - sig = signature(service_type.__init__) params = self._build_by_signature(sig, origin_service_type) return service_type(*params, *args, **kwargs) @@ -142,7 +142,7 @@ class ServiceProvider: if result.implementation is not None: return result.implementation - implementation = self._build_service(service_type, *args, **kwargs) + implementation = self._build_service(result, *args, **kwargs) if result.lifetime == ServiceLifetimeEnum.singleton: result.implementation = implementation