Compare commits
5 Commits
2025.09.25
...
2025.09.26
| Author | SHA1 | Date | |
|---|---|---|---|
| e0f6e1c241 | |||
| c410a692be | |||
| 56a16cbeba | |||
| d05d947d54 | |||
| 0529269747 |
@@ -1,9 +1,9 @@
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl import api
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.api.application.web_app import WebApp
|
||||
from cpl.api_module import ApiModule
|
||||
from cpl.application import ApplicationBuilder
|
||||
from cpl.application.application_builder import ApplicationBuilder
|
||||
from cpl.auth import AuthModule
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema import AuthUser, Role
|
||||
from cpl.core.configuration import Configuration
|
||||
@@ -35,12 +35,17 @@ def main():
|
||||
|
||||
app = builder.build()
|
||||
app.with_logging()
|
||||
app.with_database()
|
||||
|
||||
app.with_authentication()
|
||||
app.with_authorization()
|
||||
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
|
||||
app.with_route(
|
||||
path="/route1",
|
||||
fn=lambda r: JSONResponse("route1"),
|
||||
method="GET",
|
||||
authentication=True,
|
||||
permissions=[Permissions.administrator],
|
||||
)
|
||||
app.with_routes_directory("routes")
|
||||
|
||||
provider = builder.service_provider
|
||||
|
||||
@@ -4,6 +4,7 @@ from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.log import LoggerABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
from model.city import City
|
||||
from model.city_dao import CityDao
|
||||
from model.user import User
|
||||
@@ -11,8 +12,8 @@ from model.user_dao import UserDao
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
def __init__(self, services: ServiceProvider, modules: Modules):
|
||||
ApplicationABC.__init__(self, services, modules)
|
||||
|
||||
self._logger = services.get_service(LoggerABC)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from cpl.application import ApplicationBuilder
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.log import LogLevel
|
||||
from cpl.database import DatabaseModule
|
||||
from custom_permissions import CustomPermissions
|
||||
from startup import Startup
|
||||
|
||||
@@ -10,13 +11,12 @@ from startup import Startup
|
||||
def main():
|
||||
builder = ApplicationBuilder(Application).with_startup(Startup)
|
||||
builder.services.add_logging()
|
||||
|
||||
app = builder.build()
|
||||
|
||||
app.with_logging(LogLevel.trace)
|
||||
app.with_permissions(CustomPermissions)
|
||||
app.with_migrations("./scripts")
|
||||
app.with_seeders()
|
||||
# app.with_seeders()
|
||||
|
||||
Console.write_line(CustomPermissions.test.value in PermissionsRegistry.get())
|
||||
app.run()
|
||||
|
||||
@@ -6,7 +6,7 @@ from cpl.auth.permission.permission_module import PermissionsModule
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.log import Logger, LoggerABC
|
||||
from cpl.database import mysql
|
||||
from cpl.database import mysql, DatabaseModule
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.database.mysql.mysql_module import MySQLModule
|
||||
from cpl.dependency import ServiceCollection
|
||||
@@ -25,6 +25,7 @@ class Startup(StartupABC):
|
||||
@staticmethod
|
||||
async def configure_services(services: ServiceCollection):
|
||||
services.add_module(MySQLModule)
|
||||
services.add_module(DatabaseModule)
|
||||
services.add_module(AuthModule)
|
||||
services.add_module(PermissionsModule)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from cpl.core.environment import Environment
|
||||
from cpl.core.log import LoggerABC
|
||||
from cpl.core.pipes import IPAddressPipe
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
from cpl.mail import EMail, EMailClientABC
|
||||
from cpl.query import List
|
||||
from scoped_service import ScopedService
|
||||
@@ -16,8 +17,8 @@ from test_settings import TestSettings
|
||||
|
||||
class Application(ApplicationABC):
|
||||
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
def __init__(self, services: ServiceProvider, modules: Modules):
|
||||
ApplicationABC.__init__(self, services, modules)
|
||||
self._logger = self._services.get_service(LoggerABC)
|
||||
self._mailer = self._services.get_service(EMailClientABC)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
|
||||
from .logger import APILogger
|
||||
from .settings import ApiSettings
|
||||
|
||||
from .api_module import ApiModule
|
||||
|
||||
22
src/cpl-api/cpl/api/api_module.py
Normal file
22
src/cpl-api/cpl/api/api_module.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from cpl.api import ApiSettings
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.auth.auth_module import AuthModule
|
||||
from cpl.auth.permission.permission_module import PermissionsModule
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.dependency import ServiceCollection
|
||||
from cpl.dependency.module.module import Module
|
||||
|
||||
|
||||
class ApiModule(Module):
|
||||
config = [ApiSettings]
|
||||
singleton = [
|
||||
PolicyRegistry,
|
||||
RouteRegistry,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_module(DatabaseModule)
|
||||
collection.add_module(AuthModule)
|
||||
collection.add_module(PermissionsModule)
|
||||
@@ -10,7 +10,7 @@ from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ExceptionHandler
|
||||
|
||||
from cpl import api, auth
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.api.error import APIError
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||
@@ -25,20 +25,20 @@ from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.api.settings import ApiSettings
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||
from cpl.api_module import ApiModule
|
||||
from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.auth.auth_module import AuthModule
|
||||
from cpl.auth.permission.permission_module import PermissionsModule
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
super().__init__(services, [AuthModule, PermissionsModule, ApiModule])
|
||||
def __init__(self, services: ServiceProvider, modules: Modules):
|
||||
super().__init__(services, modules, [AuthModule, PermissionsModule, ApiModule])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
self._logger = services.get_service(APILogger)
|
||||
@@ -78,11 +78,6 @@ class WebApp(ApplicationABC):
|
||||
self._logger.debug(f"Allowed origins: {origins}")
|
||||
return origins.split(",")
|
||||
|
||||
def with_database(self) -> Self:
|
||||
self.with_migrations()
|
||||
self.with_seeders()
|
||||
return self
|
||||
|
||||
def with_app(self, app: Starlette) -> Self:
|
||||
assert app is not None, "app must not be None"
|
||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.auth.auth_module import AuthModule
|
||||
from cpl.auth.permission.permission_module import PermissionsModule
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.mysql.mysql_module import MySQLModule
|
||||
from cpl.dependency.module import Module, TModule
|
||||
|
||||
|
||||
class ApiModule(Module):
|
||||
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
return [AuthModule, DatabaseModule, PermissionsModule]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: "ServiceCollection"):
|
||||
collection.add_module(DatabaseModule)
|
||||
|
||||
collection.add_module(AuthModule)
|
||||
collection.add_module(PermissionsModule)
|
||||
|
||||
collection.add_singleton(PolicyRegistry)
|
||||
collection.add_singleton(RouteRegistry)
|
||||
@@ -1,2 +1,2 @@
|
||||
from .application_builder import ApplicationBuilder
|
||||
from .host import Host
|
||||
from .host import Host
|
||||
|
||||
@@ -2,10 +2,12 @@ from abc import ABC, abstractmethod
|
||||
from typing import Callable, Self
|
||||
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import module_dependency_error
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.log_settings import LogSettings
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import TModule
|
||||
|
||||
|
||||
def __not_implemented__(package: str, func: Callable):
|
||||
@@ -20,17 +22,6 @@ class ApplicationABC(ABC):
|
||||
Contains instances of prepared objects
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
|
||||
self._services = services
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
)
|
||||
|
||||
@property
|
||||
def required_modules(self) -> list[str]:
|
||||
return self._required_modules
|
||||
|
||||
@classmethod
|
||||
def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
|
||||
r"""Extend the Application with a custom method
|
||||
@@ -47,6 +38,30 @@ class ApplicationABC(ABC):
|
||||
setattr(cls, name, func)
|
||||
return cls
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, services: ServiceProvider, loaded_modules: set[TModule], required_modules: list[str | object] = None
|
||||
):
|
||||
self._services = services
|
||||
self._modules = loaded_modules
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
)
|
||||
|
||||
def validate_app_required_modules(self):
|
||||
modules_names = {x.__name__ for x in self._modules}
|
||||
for module in self._required_modules:
|
||||
if module in modules_names:
|
||||
continue
|
||||
|
||||
module_dependency_error(
|
||||
type(self).__name__,
|
||||
module.__name__,
|
||||
ImportError(
|
||||
f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
|
||||
),
|
||||
)
|
||||
|
||||
def with_logging(self, level: LogLevel = None):
|
||||
if level is None:
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
@@ -57,14 +72,21 @@ class ApplicationABC(ABC):
|
||||
logger = self._services.get_service(LoggerABC)
|
||||
logger.set_level(level)
|
||||
|
||||
def with_permissions(self, *args, **kwargs):
|
||||
__not_implemented__("cpl-auth", self.with_permissions)
|
||||
def with_permissions(self, *args):
|
||||
try:
|
||||
from cpl.auth import AuthModule
|
||||
|
||||
def with_migrations(self, *args, **kwargs):
|
||||
__not_implemented__("cpl-database", self.with_migrations)
|
||||
AuthModule.with_permissions(*args)
|
||||
except ImportError:
|
||||
__not_implemented__("cpl-auth", self.with_permissions)
|
||||
|
||||
def with_seeders(self, *args, **kwargs):
|
||||
__not_implemented__("cpl-database", self.with_seeders)
|
||||
def with_migrations(self, *args):
|
||||
try:
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
|
||||
DatabaseModule.with_migrations(self._services, *args)
|
||||
except ImportError:
|
||||
__not_implemented__("cpl-database", self.with_migrations)
|
||||
|
||||
def with_extension(self, func: Callable[[Self, ...], None], *args, **kwargs):
|
||||
r"""Extend the Application with a custom method
|
||||
@@ -84,9 +106,17 @@ class ApplicationABC(ABC):
|
||||
Called by custom Application.main
|
||||
"""
|
||||
try:
|
||||
for module in self._modules:
|
||||
if not hasattr(module, "configure") and not callable(getattr(module, "configure")):
|
||||
continue
|
||||
module.configure(self._services)
|
||||
|
||||
Host.run_app(self.main)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
logger = self._services.get_service(LoggerABC)
|
||||
logger.info("Application shutdown")
|
||||
|
||||
@abstractmethod
|
||||
def main(self): ...
|
||||
|
||||
@@ -6,7 +6,6 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
|
||||
from cpl.application.abc.startup_abc import StartupABC
|
||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.dependency.context import get_provider, use_root_provider
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
@@ -43,19 +42,6 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
|
||||
return provider
|
||||
|
||||
def validate_app_required_modules(self, app: ApplicationABC):
|
||||
for module in app.required_modules:
|
||||
if module in self._services.loaded_modules:
|
||||
continue
|
||||
|
||||
dependency_error(
|
||||
type(app).__name__,
|
||||
module,
|
||||
ImportError(
|
||||
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
|
||||
),
|
||||
)
|
||||
|
||||
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
|
||||
self._startup = startup
|
||||
return self
|
||||
@@ -84,6 +70,6 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
Host.run(extension.run, self.service_provider)
|
||||
|
||||
use_root_provider(self._services.build())
|
||||
app = self._app(self.service_provider)
|
||||
self.validate_app_required_modules(app)
|
||||
app = self._app(self.service_provider, self._services.loaded_modules)
|
||||
app.validate_app_required_modules()
|
||||
return app
|
||||
|
||||
@@ -77,4 +77,4 @@ class Host:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return cls.get_loop().run_until_complete(func(*args, **kwargs))
|
||||
|
||||
return func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@@ -1,21 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Type
|
||||
|
||||
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
||||
from cpl.auth import permission as _permission
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
||||
from .auth_module import AuthModule
|
||||
from .keycloak_settings import KeycloakSettings
|
||||
from .logger import AuthLogger
|
||||
|
||||
|
||||
def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _ApplicationABC:
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
|
||||
for perm in permissions:
|
||||
PermissionsRegistry.with_enum(perm)
|
||||
return self
|
||||
|
||||
|
||||
|
||||
_ApplicationABC.extend(_ApplicationABC.with_permissions, _with_permissions)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Type
|
||||
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
from cpl.database.mysql.mysql_module import MySQLModule
|
||||
from cpl.database.postgres.postgres_module import PostgresModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from .keycloak.keycloak_admin import KeycloakAdmin
|
||||
from .keycloak.keycloak_client import KeycloakClient
|
||||
from .schema._administration.api_key_dao import ApiKeyDao
|
||||
from .schema._administration.auth_user_dao import AuthUserDao
|
||||
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
@@ -17,28 +21,36 @@ from .schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
|
||||
class AuthModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
return [DatabaseModule]
|
||||
dependencies = [DatabaseModule, (MySQLModule, PostgresModule)]
|
||||
config = [KeycloakSettings]
|
||||
singleton = [
|
||||
KeycloakClient,
|
||||
KeycloakAdmin,
|
||||
AuthUserDao,
|
||||
ApiKeyDao,
|
||||
ApiKeyPermissionDao,
|
||||
PermissionDao,
|
||||
RoleDao,
|
||||
RolePermissionDao,
|
||||
RoleUserDao,
|
||||
]
|
||||
scoped = []
|
||||
transient = []
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_singleton(_KeycloakClient)
|
||||
collection.add_singleton(_KeycloakAdmin)
|
||||
def configure(provider: ServiceProvider):
|
||||
paths = {
|
||||
ServerTypes.POSTGRES: "scripts/postgres",
|
||||
ServerTypes.MYSQL: "scripts/mysql",
|
||||
}
|
||||
|
||||
collection.add_singleton(AuthUserDao)
|
||||
collection.add_singleton(ApiKeyDao)
|
||||
collection.add_singleton(ApiKeyPermissionDao)
|
||||
collection.add_singleton(PermissionDao)
|
||||
collection.add_singleton(RoleDao)
|
||||
collection.add_singleton(RolePermissionDao)
|
||||
collection.add_singleton(RoleUserDao)
|
||||
DatabaseModule.with_migrations(
|
||||
provider, str(os.path.join(os.path.dirname(os.path.realpath(__file__)), paths[ServerType.server_type]))
|
||||
)
|
||||
|
||||
provider = collection.build()
|
||||
migration_service: MigrationService = provider.get_service(MigrationService)
|
||||
if ServerType.server_type == ServerTypes.POSTGRES:
|
||||
migration_service.with_directory(
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
|
||||
)
|
||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
|
||||
@staticmethod
|
||||
def with_permissions(*permissions: Type[Enum]):
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
|
||||
for perm in permissions:
|
||||
PermissionsRegistry.with_enum(perm)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .permission_module import PermissionsModule
|
||||
from .permission_seeder import PermissionSeeder
|
||||
from .permissions import Permissions
|
||||
from .permissions_registry import PermissionsRegistry
|
||||
|
||||
@@ -3,18 +3,15 @@ from cpl.auth.permission.permission_seeder import PermissionSeeder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
|
||||
class PermissionsModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
|
||||
return [DatabaseModule, AuthModule]
|
||||
dependencies = [DatabaseModule, AuthModule]
|
||||
singleton = [(DataSeederABC, PermissionSeeder)]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_singleton(DataSeederABC, PermissionSeeder)
|
||||
PermissionsRegistry.with_enum(Permissions)
|
||||
|
||||
@@ -7,20 +7,21 @@ def dependency_error(src: str, package_name: str, e: ImportError = None) -> None
|
||||
Console.error(f"'{package_name}' is required to use feature: {src}. Please install it and try again.")
|
||||
tb = traceback.format_exc()
|
||||
if not tb.startswith("NoneType: None"):
|
||||
Console.write_line("->", tb)
|
||||
Console.error("->", tb)
|
||||
|
||||
elif e is not None:
|
||||
Console.write_line("->", str(e))
|
||||
Console.error(f"-> {str(e)}")
|
||||
|
||||
exit(1)
|
||||
|
||||
|
||||
def module_dependency_error(src: str, module: str, e: ImportError = None) -> None:
|
||||
Console.error(f"'{module}' is required to use feature: {src}. Please initialize it with `add_module({module})`.")
|
||||
Console.error(f"'{module}' is required by '{src}'. Please initialize it with `add_module({module})`.")
|
||||
tb = traceback.format_exc()
|
||||
if not tb.startswith("NoneType: None"):
|
||||
Console.write_line("->", tb)
|
||||
Console.error("->", tb)
|
||||
|
||||
elif e is not None:
|
||||
Console.write_line("->", str(e))
|
||||
Console.error(f"-> {str(e)}")
|
||||
|
||||
exit(1)
|
||||
exit(1)
|
||||
|
||||
@@ -93,14 +93,13 @@ class Logger(LoggerABC):
|
||||
def _log(self, level: LogLevel, *messages: Messages):
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
formatted_message = self._format_message(level.value, timestamp, *messages)
|
||||
|
||||
self._write_log_to_file(level, formatted_message)
|
||||
self._write_to_console(level, formatted_message)
|
||||
self._write_log_to_file(level, self._file_format_message(level.value, timestamp, *messages))
|
||||
self._write_to_console(level, self._console_format_message(level.value, timestamp, *messages))
|
||||
except Exception as e:
|
||||
print(f"Error while logging: {e} -> {traceback.format_exc()}")
|
||||
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
if isinstance(messages, tuple):
|
||||
messages = list(messages)
|
||||
|
||||
@@ -119,6 +118,24 @@ class Logger(LoggerABC):
|
||||
|
||||
return message
|
||||
|
||||
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
if isinstance(messages, tuple):
|
||||
messages = list(messages)
|
||||
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
messages = [str(message) for message in messages if message is not None]
|
||||
|
||||
message = f"[{level.upper():^3}]"
|
||||
message += f" [{self._file_prefix}]"
|
||||
if self._source is not None:
|
||||
message += f" - [{self._source}]"
|
||||
|
||||
message += f": {' '.join(messages)}"
|
||||
|
||||
return message
|
||||
|
||||
def header(self, string: str):
|
||||
self._log(LogLevel.info, string)
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@ class LoggerABC(ABC):
|
||||
def set_level(self, level: LogLevel): ...
|
||||
|
||||
@abstractmethod
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
|
||||
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def header(self, string: str):
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source, Messages
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class StructuredLogger(Logger):
|
||||
@@ -21,18 +19,7 @@ class StructuredLogger(Logger):
|
||||
def log_file(self):
|
||||
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
|
||||
|
||||
def _log(self, level: LogLevel, *messages: Messages):
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
formatted_message = self._format_message(level.value, timestamp, *messages)
|
||||
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
|
||||
|
||||
self._write_log_to_file(level, structured_message)
|
||||
self._write_to_console(level, formatted_message)
|
||||
except Exception as e:
|
||||
print(f"Error while logging: {e} -> {traceback.format_exc()}")
|
||||
|
||||
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
|
||||
def _file_format_message(self, level: str, timestamp: str, *messages: Messages) -> str:
|
||||
structured_message = {
|
||||
"timestamp": timestamp,
|
||||
"level": level.upper(),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import inspect
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.log import LoggerABC, LogLevel
|
||||
from cpl.core.log import LoggerABC, LogLevel, StructuredLogger
|
||||
from cpl.core.typing import Messages
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
@@ -31,8 +31,11 @@ class WrappedLogger(LoggerABC):
|
||||
def set_level(self, level: LogLevel):
|
||||
self._logger.set_level(level)
|
||||
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._format_message(level, timestamp, *messages)
|
||||
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._file_format_message(level, timestamp, *messages)
|
||||
|
||||
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._console_format_message(level, timestamp, *messages)
|
||||
|
||||
@staticmethod
|
||||
def _get_source() -> str | None:
|
||||
@@ -48,6 +51,7 @@ class WrappedLogger(LoggerABC):
|
||||
ServiceCollection,
|
||||
WrappedLogger,
|
||||
WrappedLogger.__subclasses__(),
|
||||
StructuredLogger,
|
||||
]
|
||||
|
||||
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
|
||||
|
||||
@@ -1,34 +1,5 @@
|
||||
import os
|
||||
|
||||
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
||||
from . import mysql as _mysql
|
||||
from . import postgres as _postgres
|
||||
from .database_module import DatabaseModule
|
||||
from .logger import DBLogger
|
||||
from .table_manager import TableManager
|
||||
|
||||
|
||||
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
|
||||
migration_service = self._services.get_service(MigrationService)
|
||||
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
|
||||
|
||||
if isinstance(paths, str):
|
||||
paths = [paths]
|
||||
|
||||
for path in paths:
|
||||
migration_service.with_directory(path)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _with_seeders(self: _ApplicationABC) -> _ApplicationABC:
|
||||
from cpl.database.service.seeder_service import SeederService
|
||||
from cpl.application.host import Host
|
||||
|
||||
seeder_service: SeederService = self._services.get_service(SeederService)
|
||||
Host.run(seeder_service.seed)
|
||||
return self
|
||||
|
||||
|
||||
_ApplicationABC.extend(_ApplicationABC.with_migrations, _with_migrations)
|
||||
_ApplicationABC.extend(_ApplicationABC.with_seeders, _with_seeders)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from .connection_abc import ConnectionABC
|
||||
from .data_access_object_abc import DataAccessObjectABC
|
||||
from .data_seeder_abc import DataSeederABC
|
||||
from .db_context_abc import DBContextABC
|
||||
from .db_join_model_abc import DbJoinModelABC
|
||||
from .db_model_abc import DbModelABC
|
||||
|
||||
@@ -14,7 +14,7 @@ from cpl.database.logger import DBLogger
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Type
|
||||
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.table_manager import TableManager
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.database.abc.db_model_abc import DbModelABC
|
||||
|
||||
|
||||
@@ -1,22 +1,33 @@
|
||||
from cpl.core.errors import module_dependency_error
|
||||
from cpl.database.model.server_type import ServerType
|
||||
from cpl.database.model.database_settings import DatabaseSettings
|
||||
from cpl.database.mysql.mysql_module import MySQLModule
|
||||
from cpl.database.postgres.postgres_module import PostgresModule
|
||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
from cpl.database.service.seeder_service import SeederService
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class DatabaseModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
if not ServerType.has_server_type:
|
||||
module_dependency_error(__name__, "MySQLModule or PostgresModule")
|
||||
dependencies = [(MySQLModule, PostgresModule)]
|
||||
config = [DatabaseSettings]
|
||||
singleton = [
|
||||
ExecutedMigrationDao,
|
||||
MigrationService,
|
||||
SeederService,
|
||||
]
|
||||
|
||||
return []
|
||||
@classmethod
|
||||
def configure(cls, provider: ServiceProvider): ...
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_singleton(ExecutedMigrationDao)
|
||||
collection.add_singleton(MigrationService)
|
||||
collection.add_singleton(SeederService)
|
||||
def with_migrations(services: ServiceProvider, *paths: str | list[str]):
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
|
||||
migration_service = services.get_service(MigrationService)
|
||||
|
||||
if isinstance(paths, str):
|
||||
paths = [paths]
|
||||
|
||||
for path in paths:
|
||||
migration_service.with_directory(path)
|
||||
|
||||
@@ -21,4 +21,4 @@ class DatabaseSettings(ConfigurationModelABC):
|
||||
self.option("use_unicode", bool, False)
|
||||
self.option("buffered", bool, False)
|
||||
self.option("auth_plugin", str, "caching_sha2_password")
|
||||
self.option("ssl_disabled", bool, False)
|
||||
self.option("ssl_disabled", bool, True)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .connection import DatabaseConnection
|
||||
from .db_context import DBContext
|
||||
from .mysql_module import MySQLModule
|
||||
from .mysql_pool import MySQLPool
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.model.database_settings import DatabaseSettings
|
||||
from cpl.database.model.server_type import ServerTypes, ServerType
|
||||
from cpl.database.mysql.db_context import DBContext
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
|
||||
class MySQLModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
return []
|
||||
config = [DatabaseSettings]
|
||||
singleton = [(DBContextABC, DBContext)]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
ServerType.set_server_type(ServerTypes(ServerTypes.MYSQL.value))
|
||||
Configuration.set("DB_DEFAULT_PORT", 3306)
|
||||
|
||||
collection.add_singleton(DBContextABC, DBContext)
|
||||
|
||||
@@ -22,27 +22,27 @@ class MySQLPool:
|
||||
"use_unicode": database_settings.use_unicode,
|
||||
"buffered": database_settings.buffered,
|
||||
"auth_plugin": database_settings.auth_plugin,
|
||||
"ssl_disabled": False,
|
||||
"ssl_disabled": database_settings.ssl_disabled,
|
||||
}
|
||||
self._pool: Optional[MySQLConnectionPool] = None
|
||||
|
||||
async def _get_pool(self):
|
||||
if self._pool is None:
|
||||
self._pool = MySQLConnectionPool(
|
||||
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
||||
)
|
||||
await self._pool.initialize_pool()
|
||||
|
||||
con = await self._pool.get_connection()
|
||||
try:
|
||||
self._pool = MySQLConnectionPool(
|
||||
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
||||
)
|
||||
await self._pool.initialize_pool()
|
||||
|
||||
con = await self._pool.get_connection()
|
||||
async with await con.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
|
||||
await con.close()
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Error connecting to the database: {e}")
|
||||
finally:
|
||||
await con.close()
|
||||
logger.fatal(f"Error connecting to the database", e)
|
||||
|
||||
return self._pool
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .db_context import DBContext
|
||||
from .postgres_module import PostgresModule
|
||||
from .postgres_pool import PostgresPool
|
||||
from .sql_select_builder import SQLSelectBuilder
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.database.model.database_settings import DatabaseSettings
|
||||
from cpl.database.model.server_type import ServerTypes, ServerType
|
||||
from cpl.database.postgres.db_context import DBContext
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
|
||||
class PostgresModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
return [DatabaseModule]
|
||||
config = [DatabaseSettings]
|
||||
singleton = [(DBContextABC, DBContext)]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
ServerType.set_server_type(ServerTypes(ServerTypes.POSTGRES.value))
|
||||
Configuration.set("DB_DEFAULT_PORT", 5432)
|
||||
|
||||
collection.add_singleton(DBContextABC, DBContext)
|
||||
|
||||
@@ -7,7 +7,7 @@ from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class PostgresPool:
|
||||
@@ -31,15 +31,16 @@ class PostgresPool:
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||
)
|
||||
await pool.open()
|
||||
try:
|
||||
await pool.open()
|
||||
async with pool.connection() as con:
|
||||
await pool.check_connection(con)
|
||||
|
||||
self._pool = pool
|
||||
except PoolTimeout as e:
|
||||
await pool.close()
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Failed to connect to the database", e)
|
||||
self._pool = pool
|
||||
|
||||
return self._pool
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from .executed_migration import ExecutedMigration
|
||||
from .executed_migration_dao import ExecutedMigrationDao
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.table_manager import TableManager
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from .seeder_service import SeederService
|
||||
from .migration_service import MigrationService
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import glob
|
||||
import os
|
||||
|
||||
from cpl.database.abc import DBContextABC
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import Migration
|
||||
from cpl.database.model.migration import Migration
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||
from cpl.dependency.hosted.startup_task import StartupTask
|
||||
from cpl.dependency.hosted import StartupTask
|
||||
|
||||
|
||||
class MigrationService(StartupTask):
|
||||
|
||||
def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao):
|
||||
StartupTask.__init__(self)
|
||||
self._logger = logger
|
||||
self._db = db
|
||||
self._executed_migration_dao = executed_migration_dao
|
||||
@@ -24,12 +23,23 @@ class MigrationService(StartupTask):
|
||||
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/postgres"))
|
||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
|
||||
else:
|
||||
raise Exception("Unsupported database type")
|
||||
|
||||
async def run(self):
|
||||
await self._execute(self._load_scripts())
|
||||
|
||||
def with_directory(self, directory: str) -> "MigrationService":
|
||||
self._script_directories.append(directory)
|
||||
cpl_rel_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../..")
|
||||
cpl_abs_path = os.path.abspath(cpl_rel_path)
|
||||
|
||||
if directory.startswith(cpl_abs_path) or os.path.abspath(directory).startswith(cpl_abs_path):
|
||||
if len(self._script_directories) > 0:
|
||||
self._script_directories.insert(1, directory)
|
||||
else:
|
||||
self._script_directories.append(directory)
|
||||
else:
|
||||
self._script_directories.append(directory)
|
||||
return self
|
||||
|
||||
async def _get_migration_history(self) -> list[ExecutedMigration]:
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.hosted import StartupTask
|
||||
|
||||
|
||||
class SeederService:
|
||||
class SeederService(StartupTask):
|
||||
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
StartupTask.__init__(self)
|
||||
self._provider = provider
|
||||
self._logger = provider.get_service(DBLogger)
|
||||
|
||||
async def seed(self):
|
||||
async def run(self):
|
||||
seeders = self._provider.get_services(DataSeederABC)
|
||||
self._logger.debug(f"Found {len(seeders)} seeders")
|
||||
for seeder in seeders:
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .hosted_service import HostedService
|
||||
from .startup_task import StartupTask
|
||||
from .startup_task import StartupTask
|
||||
|
||||
@@ -6,4 +6,4 @@ class HostedService(ABC):
|
||||
async def start(self): ...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self): ...
|
||||
async def stop(self): ...
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Type
|
||||
|
||||
TModule = Type["Module"]
|
||||
|
||||
class Module(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def dependencies() -> list[TModule]: ...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def register(collection: "ServiceCollection"): ...
|
||||
10
src/cpl-dependency/cpl/dependency/module/module.py
Normal file
10
src/cpl-dependency/cpl/dependency/module/module.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.dependency.module.module_abc import ModuleABC
|
||||
|
||||
|
||||
class Module(ModuleABC):
|
||||
|
||||
@staticmethod
|
||||
def register(collection: "ServiceCollection"): ...
|
||||
|
||||
@staticmethod
|
||||
def configure(provider: "ServiceProvider"): ...
|
||||
60
src/cpl-dependency/cpl/dependency/module/module_abc.py
Normal file
60
src/cpl-dependency/cpl/dependency/module/module_abc.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import isclass
|
||||
|
||||
from cpl.core.configuration import ConfigurationModelABC
|
||||
|
||||
|
||||
class ModuleABC(ABC):
|
||||
__OPTIONAL_VARS = ["dependencies", "configuration", "singleton", "scoped", "transient", "hosted"]
|
||||
|
||||
def __init_subclass__(cls):
|
||||
super().__init_subclass__()
|
||||
|
||||
if f"{cls.__module__}.{cls.__name__}" == "cpl.dependency.module.module.Module":
|
||||
return
|
||||
|
||||
for var in cls.__OPTIONAL_VARS:
|
||||
if not hasattr(cls, var):
|
||||
continue
|
||||
|
||||
value = getattr(cls, var)
|
||||
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"'{var}' attribute of {cls.__name__} must be a list, not {type(value).__name__}")
|
||||
|
||||
for dep in value:
|
||||
if var == "config":
|
||||
if not isclass(dep) or not issubclass(dep, ConfigurationModelABC):
|
||||
raise TypeError(
|
||||
f"Invalid config {dep} in {cls.__name__}: must be subclass of ConfigurationModelABC"
|
||||
)
|
||||
elif var == "dependencies":
|
||||
if not isinstance(dep, (list, tuple)) and not isclass(dep):
|
||||
raise TypeError(f"Invalid dependency {dep} in {cls.__name__}")
|
||||
else:
|
||||
if not isinstance(dep, tuple) and not isclass(dep):
|
||||
raise TypeError(f"Invalid {var} {dep} in {cls.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_singleton(cls):
|
||||
return getattr(cls, "singleton", [])
|
||||
|
||||
@classmethod
|
||||
def get_scoped(cls):
|
||||
return getattr(cls, "scoped", [])
|
||||
|
||||
@classmethod
|
||||
def get_transient(cls):
|
||||
return getattr(cls, "transient", [])
|
||||
|
||||
@classmethod
|
||||
def get_hosted(cls):
|
||||
return getattr(cls, "hosted", [])
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def register(collection: "ServiceCollection"): ...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def configure(provider: "ServiceProvider"): ...
|
||||
17
src/cpl-dependency/cpl/dependency/module/module_protocol.py
Normal file
17
src/cpl-dependency/cpl/dependency/module/module_protocol.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Protocol
|
||||
|
||||
from cpl.dependency.typing import TService, TModule, TConfig
|
||||
|
||||
|
||||
class ModuleProtocol(Protocol):
|
||||
dependencies: list[TModule | TService] = []
|
||||
config: list[TConfig] = []
|
||||
singleton: list[TService] = []
|
||||
scoped: list[TService] = []
|
||||
transient: list[TService] = []
|
||||
|
||||
@staticmethod
|
||||
def register(collection: "ServiceCollection"): ...
|
||||
|
||||
@staticmethod
|
||||
def configure(provider: "ServiceProvider"): ...
|
||||
@@ -1,14 +1,16 @@
|
||||
from inspect import isclass
|
||||
from typing import Union, Type, Callable, Self
|
||||
from typing import Union, Callable, Self, Type
|
||||
|
||||
from cpl.core.errors import module_dependency_error
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.core.typing import T, Service
|
||||
from cpl.core.utils.cache import Cache
|
||||
from cpl.dependency.hosted.startup_task import StartupTask
|
||||
from cpl.dependency.module import Module
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import TModule, TService, TStartupTask
|
||||
|
||||
|
||||
class ServiceCollection:
|
||||
@@ -16,19 +18,112 @@ class ServiceCollection:
|
||||
|
||||
_modules: dict[str, Callable] = {}
|
||||
|
||||
@classmethod
|
||||
def with_module(cls, func: Callable, name: str = None) -> type[Self]:
|
||||
# cls._modules[func.__name__ if name is None else name] = func
|
||||
return cls
|
||||
|
||||
def __init__(self):
|
||||
self._service_descriptors: list[ServiceDescriptor] = []
|
||||
self._loaded_modules: set[str] = set()
|
||||
self._loaded_modules: set[TModule] = set()
|
||||
|
||||
@property
|
||||
def loaded_modules(self) -> set[str]:
|
||||
def loaded_modules(self) -> set[TModule]:
|
||||
return self._loaded_modules
|
||||
|
||||
def _check_dependency(self, module: TModule, dependency: TModule | TService, optional: bool = False) -> bool:
|
||||
if not issubclass(dependency, Module):
|
||||
found_services = [
|
||||
x
|
||||
for x in self._service_descriptors
|
||||
if x.service_type == dependency or x.base_type == dependency or isinstance(x.implementation, dependency)
|
||||
]
|
||||
|
||||
if len(found_services) > 0:
|
||||
return True
|
||||
|
||||
if optional:
|
||||
return False
|
||||
|
||||
module_dependency_error(module.__name__, dependency.__name__)
|
||||
|
||||
if dependency not in self._loaded_modules:
|
||||
if optional:
|
||||
return False
|
||||
|
||||
module_dependency_error(module.__name__, dependency.__name__)
|
||||
|
||||
return True
|
||||
|
||||
def _add_module_service(self, service: TService | tuple[TService, TService], lifetime: ServiceLifetimeEnum):
|
||||
args = ()
|
||||
|
||||
if isinstance(service, tuple):
|
||||
if len(service) != 2:
|
||||
raise ValueError("Service must be a tuple in the format (XABC, X)")
|
||||
|
||||
k, v = service
|
||||
if not (isinstance(k, type) and isinstance(v, type)):
|
||||
raise ValueError("Service tuple must have elements in the format (XABC, X)")
|
||||
args = (k, v)
|
||||
else:
|
||||
if not isinstance(service, type):
|
||||
raise ValueError("Service must be a type or a tuple of two types")
|
||||
args = (service,)
|
||||
|
||||
match lifetime:
|
||||
case ServiceLifetimeEnum.singleton:
|
||||
self.add_singleton(*args)
|
||||
case ServiceLifetimeEnum.scoped:
|
||||
self.add_scoped(*args)
|
||||
case ServiceLifetimeEnum.transient:
|
||||
self.add_transient(*args)
|
||||
case ServiceLifetimeEnum.hosted:
|
||||
self.add_hosted_service(*args)
|
||||
case _:
|
||||
raise ValueError(f"Unknown service lifetime: {lifetime}")
|
||||
|
||||
def _add_module_services(self, module: TModule):
|
||||
for s in module.get_singleton():
|
||||
self._add_module_service(s, ServiceLifetimeEnum.singleton)
|
||||
|
||||
for s in module.get_scoped():
|
||||
self._add_module_service(s, ServiceLifetimeEnum.scoped)
|
||||
|
||||
for s in module.get_transient():
|
||||
self._add_module_service(s, ServiceLifetimeEnum.transient)
|
||||
|
||||
for s in module.get_hosted():
|
||||
self._add_module_service(s, ServiceLifetimeEnum.hosted)
|
||||
|
||||
def _add_module_configuration(self, module: TModule):
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
|
||||
|
||||
configs = getattr(module, "configuration", [])
|
||||
for config in configs:
|
||||
if not issubclass(config, ConfigurationModelABC):
|
||||
raise TypeError(
|
||||
f"Invalid config {config} in {module.__name__}: must be subclass of ConfigurationModelABC"
|
||||
)
|
||||
|
||||
cfg = Configuration.get(config)
|
||||
if cfg is None:
|
||||
continue
|
||||
self.add_singleton(cfg)
|
||||
|
||||
def _check_dependencies(self, module: TModule):
|
||||
dependencies: list[TModule | Type] = getattr(module, "dependencies", [])
|
||||
for dependency in dependencies:
|
||||
if isinstance(dependency, (list, tuple)):
|
||||
deps_exists = [self._check_dependency(module, dep, optional=True) for dep in dependency]
|
||||
|
||||
if not any(deps_exists):
|
||||
if len(dependency) > 1:
|
||||
names = ", ".join([dep.__name__ for dep in dependency[:-1]]) + f" or {dependency[-1].__name__}"
|
||||
else:
|
||||
names = dependency[0].__name__
|
||||
|
||||
module_dependency_error(module.__name__, names)
|
||||
continue
|
||||
|
||||
self._check_dependency(module, dependency)
|
||||
|
||||
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
|
||||
found = False
|
||||
for descriptor in self._service_descriptors:
|
||||
@@ -44,7 +139,9 @@ class ServiceCollection:
|
||||
|
||||
self._service_descriptors.append(ServiceDescriptor(service, lifetime, base_type))
|
||||
|
||||
def _add_descriptor_by_lifetime(self, service_type: Type, lifetime: ServiceLifetimeEnum, service: Callable = None):
|
||||
def _add_descriptor_by_lifetime(
|
||||
self, service_type: TService | T, lifetime: ServiceLifetimeEnum, service: Callable = None
|
||||
):
|
||||
if service is not None:
|
||||
self._add_descriptor(service, lifetime, service_type)
|
||||
else:
|
||||
@@ -52,19 +149,19 @@ class ServiceCollection:
|
||||
|
||||
return self
|
||||
|
||||
def add_singleton(self, service_type: T, service: Service = None) -> Self:
|
||||
def add_singleton(self, service_type: TService | T, service: Service = None) -> Self:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service)
|
||||
return self
|
||||
|
||||
def add_scoped(self, service_type: T, service: Service = None) -> Self:
|
||||
def add_scoped(self, service_type: TService | T, service: Service = None) -> Self:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service)
|
||||
return self
|
||||
|
||||
def add_transient(self, service_type: T, service: Service = None) -> Self:
|
||||
def add_transient(self, service_type: TService | T, service: Service = None) -> Self:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
||||
return self
|
||||
|
||||
def add_startup_task(self, task: Type[StartupTask]) -> Self:
|
||||
def add_startup_task(self, task: TStartupTask) -> Self:
|
||||
self.add_singleton(StartupTask, task)
|
||||
return self
|
||||
|
||||
@@ -76,22 +173,20 @@ class ServiceCollection:
|
||||
sp = ServiceProvider(self._service_descriptors)
|
||||
return sp
|
||||
|
||||
def add_module(self, module: Type[Module]) -> Self:
|
||||
def add_module(self, module: TModule) -> Self:
|
||||
assert isclass(module), "Module must be a Module"
|
||||
assert issubclass(module, Module), f"Module must be subclass of {Module.__name__}"
|
||||
|
||||
name = module.__name__
|
||||
if module in self._modules:
|
||||
raise ValueError(f"Module {module} not found")
|
||||
raise ValueError(f"Module {module.__name__} is already registered")
|
||||
|
||||
for dependency in module.dependencies():
|
||||
if dependency.__name__ not in self._loaded_modules:
|
||||
self.add_module(dependency)
|
||||
self._check_dependencies(module)
|
||||
self._add_module_configuration(module)
|
||||
self._add_module_services(module)
|
||||
module.register(self)
|
||||
|
||||
module().register(self)
|
||||
|
||||
if name not in self._loaded_modules:
|
||||
self._loaded_modules.add(name)
|
||||
if module not in self._loaded_modules:
|
||||
self._loaded_modules.add(module)
|
||||
|
||||
return self
|
||||
|
||||
@@ -114,6 +209,6 @@ class ServiceCollection:
|
||||
self.add_transient(wrapper)
|
||||
return self
|
||||
|
||||
def add_cache(self, t: Type[T]):
|
||||
def add_cache(self, t: TService):
|
||||
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
|
||||
return self
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import inspect
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
from inspect import signature, Parameter, Signature
|
||||
@@ -23,7 +24,11 @@ class ServiceProvider:
|
||||
type_args = list(typing.get_args(service_type))
|
||||
|
||||
for descriptor in self._service_descriptors:
|
||||
if typing.get_origin(service_type) is None and (descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type)):
|
||||
if typing.get_origin(service_type) is None and (
|
||||
descriptor.service_type == service_type
|
||||
or typing.get_origin(descriptor.base_type) is None
|
||||
and issubclass(descriptor.base_type, service_type)
|
||||
):
|
||||
return descriptor
|
||||
|
||||
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type
|
||||
@@ -65,9 +70,7 @@ class ServiceProvider:
|
||||
implementations.append(descriptor.implementation)
|
||||
continue
|
||||
|
||||
implementation = self._build_service(
|
||||
descriptor, *args, origin_service_type=service_type, **kwargs
|
||||
)
|
||||
implementation = self._build_service(descriptor, *args, origin_service_type=service_type, **kwargs)
|
||||
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
|
||||
descriptor.implementation = implementation
|
||||
|
||||
@@ -75,16 +78,51 @@ class ServiceProvider:
|
||||
|
||||
return implementations
|
||||
|
||||
def _get_source(self):
|
||||
stack = inspect.stack()
|
||||
if len(stack) <= 1:
|
||||
return None
|
||||
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
ignore_classes = [
|
||||
ServiceProvider,
|
||||
ServiceProvider.__subclasses__(),
|
||||
ServiceCollection,
|
||||
]
|
||||
|
||||
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
|
||||
|
||||
for i, frame_info in enumerate(stack[1:]):
|
||||
module = inspect.getmodule(frame_info.frame)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
if module.__name__ in ignore_classes or module in ignore_classes:
|
||||
continue
|
||||
|
||||
if module in ignore_modules or module.__name__ in ignore_modules:
|
||||
continue
|
||||
|
||||
if module.__name__ != __name__:
|
||||
return module.__name__
|
||||
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]:
|
||||
params = []
|
||||
for param in sig.parameters.items():
|
||||
parameter = param[1]
|
||||
if parameter.name != "self" and parameter.annotation != Parameter.empty:
|
||||
if typing.get_origin(parameter.annotation) == list:
|
||||
params.append(self._get_services(typing.get_args(parameter.annotation)[0], service_type=origin_service_type))
|
||||
params.append(
|
||||
self._get_services(typing.get_args(parameter.annotation)[0], service_type=origin_service_type)
|
||||
)
|
||||
|
||||
elif parameter.annotation == Source:
|
||||
params.append(origin_service_type.__name__)
|
||||
params.append(
|
||||
origin_service_type.__name__
|
||||
if inspect.isclass(origin_service_type)
|
||||
else str(origin_service_type)
|
||||
)
|
||||
|
||||
elif issubclass(parameter.annotation, ServiceProvider):
|
||||
params.append(self)
|
||||
@@ -104,12 +142,17 @@ class ServiceProvider:
|
||||
|
||||
return params
|
||||
|
||||
def _build_service(self, descriptor: ServiceDescriptor, *args, origin_service_type: type = None, **kwargs) -> object:
|
||||
def _build_service(
|
||||
self, descriptor: ServiceDescriptor, *args, origin_service_type: type = None, **kwargs
|
||||
) -> object:
|
||||
if descriptor.implementation is not None:
|
||||
service_type = type(descriptor.implementation)
|
||||
else:
|
||||
service_type = descriptor.service_type
|
||||
|
||||
if origin_service_type is None:
|
||||
origin_service_type = self._get_source()
|
||||
|
||||
if origin_service_type is None:
|
||||
origin_service_type = service_type
|
||||
|
||||
@@ -131,7 +174,11 @@ class ServiceProvider:
|
||||
yield scoped_provider
|
||||
|
||||
def get_hosted_services(self) -> list[Optional[T]]:
|
||||
hosted_services = [self.get_service(d.service_type) for d in self._service_descriptors if d.lifetime == ServiceLifetimeEnum.hosted]
|
||||
hosted_services = [
|
||||
self.get_service(d.service_type)
|
||||
for d in self._service_descriptors
|
||||
if d.lifetime == ServiceLifetimeEnum.hosted
|
||||
]
|
||||
return hosted_services
|
||||
|
||||
def get_service(self, service_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||
|
||||
12
src/cpl-dependency/cpl/dependency/typing.py
Normal file
12
src/cpl-dependency/cpl/dependency/typing.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.configuration import ConfigurationModelABC
|
||||
from cpl.core.typing import T
|
||||
from cpl.dependency.hosted import StartupTask
|
||||
from cpl.dependency.module.module import Module
|
||||
|
||||
TModule = Type[Module]
|
||||
Modules = set[TModule]
|
||||
TService = Type[T]
|
||||
TConfig = Type[ConfigurationModelABC]
|
||||
TStartupTask = Type[StartupTask]
|
||||
@@ -1,20 +1,6 @@
|
||||
from cpl.dependency import ServiceCollection as _ServiceCollection
|
||||
from .abc.email_client_abc import EMailClientABC
|
||||
from .email_client import EMailClient
|
||||
from .email_client_settings import EMailClientSettings
|
||||
from .email_model import EMail
|
||||
from .logger import MailLogger
|
||||
|
||||
|
||||
def add_mail(collection: _ServiceCollection):
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.log import LoggerABC
|
||||
|
||||
try:
|
||||
collection.add_singleton(EMailClientABC, EMailClient)
|
||||
collection.add_transient(LoggerABC, MailLogger)
|
||||
except ImportError as e:
|
||||
Console.error("cpl-translation is not installed", str(e))
|
||||
|
||||
|
||||
_ServiceCollection.with_module(add_mail, __name__)
|
||||
from .mail_module import MailModule
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
from cpl.dependency import ServiceCollection
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.mail.email_client import EMailClient
|
||||
from cpl.dependency.module.module import Module
|
||||
|
||||
from cpl.mail.abc.email_client_abc import EMailClientABC
|
||||
from cpl.mail.email_client import EMailClient
|
||||
|
||||
|
||||
class MailModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
return []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_singleton(EMailClientABC, EMailClient)
|
||||
singleton = [(EMailClientABC, EMailClient)]
|
||||
|
||||
@@ -1,22 +1,5 @@
|
||||
from cpl.dependency import ServiceCollection as _ServiceCollection
|
||||
from .translate_pipe import TranslatePipe
|
||||
from .translation_module import TranslationModule
|
||||
from .translation_service import TranslationService
|
||||
from .translation_service_abc import TranslationServiceABC
|
||||
from .translation_settings import TranslationSettings
|
||||
|
||||
|
||||
def add_translation(collection: _ServiceCollection):
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.pipes import PipeABC
|
||||
from cpl.translation.translate_pipe import TranslatePipe
|
||||
from cpl.translation.translation_service import TranslationService
|
||||
from cpl.translation.translation_service_abc import TranslationServiceABC
|
||||
|
||||
try:
|
||||
collection.add_singleton(TranslationServiceABC, TranslationService)
|
||||
collection.add_transient(PipeABC, TranslatePipe)
|
||||
except ImportError as e:
|
||||
Console.error("cpl-translation is not installed", str(e))
|
||||
|
||||
|
||||
_ServiceCollection.with_module(add_translation, __name__)
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
from cpl.dependency import ServiceCollection
|
||||
from cpl.dependency.module import Module, TModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.translation.translation_service import TranslationService
|
||||
from cpl.translation.translation_service_abc import TranslationServiceABC
|
||||
|
||||
|
||||
class TranslationModule(Module):
|
||||
@staticmethod
|
||||
def dependencies() -> list[TModule]:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_singleton(TranslationServiceABC, TranslationService)
|
||||
singleton = [(TranslationServiceABC, TranslationService)]
|
||||
|
||||
Reference in New Issue
Block a user