Minor DI fixes & cleanup
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
Build on push / prepare (push) Successful in 9s
Build on push / core (push) Successful in 19s
Build on push / query (push) Successful in 27s
Build on push / dependency (push) Successful in 18s
Build on push / application (push) Successful in 15s
Build on push / database (push) Successful in 20s
Build on push / translation (push) Successful in 19s
Build on push / mail (push) Successful in 20s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s

This commit is contained in:
2025-09-25 10:29:40 +02:00
parent 55a727c482
commit cf4aa8291f
14 changed files with 46 additions and 57 deletions

View File

@@ -83,6 +83,7 @@ class ApplicationBuilder(Generic[TApp]):
for extension in self._app_extensions:
Host.run(extension.run, self.service_provider)
use_root_provider(self._services.build())
app = self._app(self.service_provider)
self.validate_app_required_modules(app)
return app

View File

@@ -1,8 +1,8 @@
from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_seeder import PermissionSeeder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.dependency.module import Module, TModule
from cpl.dependency.service_collection import ServiceCollection
@@ -12,21 +12,7 @@ class PermissionsModule(Module):
def dependencies() -> list[TModule]:
from cpl.database.database_module import DatabaseModule
r = [DatabaseModule]
match ServerType.server_type:
case ServerTypes.POSTGRES:
from cpl.database.postgres.postgres_module import PostgresModule
r.append(PostgresModule)
case ServerTypes.MYSQL:
from cpl.database.mysql.mysql_module import MySQLModule
r.append(MySQLModule)
case _:
raise Exception(f"Unsupported database type: {ServerType.server_type}")
return r
return [DatabaseModule, AuthModule]
@staticmethod
def register(collection: ServiceCollection):

View File

@@ -2,7 +2,7 @@ import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider")
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider: "ServiceProvider"):

View File

@@ -23,6 +23,9 @@ class ServiceProvider:
type_args = list(typing.get_args(service_type))
for descriptor in self._service_descriptors:
if typing.get_origin(service_type) is None and (descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type)):
return descriptor
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type
descriptor_type_args = list(typing.get_args(descriptor.base_type))
@@ -48,7 +51,7 @@ class ServiceProvider:
if descriptor.implementation is not None:
return descriptor.implementation
implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
implementation = self._build_service(descriptor, origin_service_type=origin_service_type)
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
descriptor.implementation = implementation
@@ -63,7 +66,7 @@ class ServiceProvider:
continue
implementation = self._build_service(
descriptor.service_type, *args, origin_service_type=service_type, **kwargs
descriptor, *args, origin_service_type=service_type, **kwargs
)
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
descriptor.implementation = implementation
@@ -78,7 +81,7 @@ class ServiceProvider:
parameter = param[1]
if parameter.name != "self" and parameter.annotation != Parameter.empty:
if typing.get_origin(parameter.annotation) == list:
params.append(self._get_services(typing.get_args(parameter.annotation)[0], origin_service_type))
params.append(self._get_services(typing.get_args(parameter.annotation)[0], service_type=origin_service_type))
elif parameter.annotation == Source:
params.append(origin_service_type.__name__)
@@ -101,18 +104,15 @@ class ServiceProvider:
return params
def _build_service(self, service_type: type, *args, origin_service_type: type = None, **kwargs) -> object:
def _build_service(self, descriptor: ServiceDescriptor, *args, origin_service_type: type = None, **kwargs) -> object:
if descriptor.implementation is not None:
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
if origin_service_type is None:
origin_service_type = service_type
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
if descriptor.implementation is not None:
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
break
sig = signature(service_type.__init__)
params = self._build_by_signature(sig, origin_service_type)
return service_type(*params, *args, **kwargs)
@@ -142,7 +142,7 @@ class ServiceProvider:
if result.implementation is not None:
return result.implementation
implementation = self._build_service(service_type, *args, **kwargs)
implementation = self._build_service(result, *args, **kwargs)
if result.lifetime == ServiceLifetimeEnum.singleton:
result.implementation = implementation