diff --git a/example/custom/api/src/main.py b/example/custom/api/src/main.py index 5f5d2428..3b2d8157 100644 --- a/example/custom/api/src/main.py +++ b/example/custom/api/src/main.py @@ -9,7 +9,7 @@ from cpl.core.configuration import Configuration from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.utils.cache import Cache -from custom.api.src.scoped_service import ScopedService +from scoped_service import ScopedService from service import PingService diff --git a/example/custom/api/src/routes/ping.py b/example/custom/api/src/routes/ping.py index 7bb40ad2..6abfc976 100644 --- a/example/custom/api/src/routes/ping.py +++ b/example/custom/api/src/routes/ping.py @@ -7,7 +7,7 @@ from cpl.api import APILogger from cpl.api.router import Router from cpl.core.console import Console from cpl.dependency import ServiceProvider -from custom.api.src.scoped_service import ScopedService +from scoped_service import ScopedService @Router.authenticate() diff --git a/src/cpl-application/cpl/application/host.py b/src/cpl-application/cpl/application/host.py index 4db27b6a..b540f33f 100644 --- a/src/cpl-application/cpl/application/host.py +++ b/src/cpl-application/cpl/application/host.py @@ -1,6 +1,9 @@ import asyncio from typing import Callable +from cpl.dependency import get_provider +from cpl.dependency.hosted.startup_task import StartupTask + class Host: _loop = asyncio.get_event_loop() @@ -9,8 +12,20 @@ class Host: def get_loop(cls): return cls._loop + @classmethod + def run_start_tasks(cls): + provider = get_provider() + tasks = provider.get_services(StartupTask) + for task in tasks: + if asyncio.iscoroutinefunction(task.run): + cls._loop.run_until_complete(task.run()) + else: + task.run() + @classmethod def run(cls, func: Callable, *args, **kwargs): + cls.run_start_tasks() + if asyncio.iscoroutinefunction(func): return cls._loop.run_until_complete(func(*args, **kwargs)) diff --git a/src/cpl-database/cpl/database/__init__.py b/src/cpl-database/cpl/database/__init__.py index 0029d995..184b398e 100644 --- a/src/cpl-database/cpl/database/__init__.py +++ b/src/cpl-database/cpl/database/__init__.py @@ -9,7 +9,6 @@ from .table_manager import TableManager def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC: - from cpl.application.host import Host from cpl.database.service.migration_service import MigrationService migration_service = self._services.get_service(MigrationService) @@ -21,8 +20,6 @@ def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _Applica for path in paths: migration_service.with_directory(path) - Host.run(migration_service.migrate) - return self diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py index 25d40b74..e7014809 100644 --- a/src/cpl-database/cpl/database/abc/data_access_object_abc.py +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -11,6 +11,7 @@ from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.const import DATETIME_FORMAT from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.logger import DBLogger +from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.postgres.sql_select_builder import SQLSelectBuilder from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts from cpl.dependency import get_provider @@ -351,13 +352,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}" return f""" - INSERT INTO {self._table_name} ( - {fields} - ) VALUES ( - {values} - ) - RETURNING {self.__primary_key}; - """ + INSERT INTO {self._table_name} ( + {fields} + ) VALUES ( + {values} + ) + {"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"} + """ async def create(self, obj: T_DBM, skip_editor=False) -> int: self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}") diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index 44f648ff..fe8110f0 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool from cpl.core.environment import Environment from cpl.database.logger import DBLogger from cpl.database.model import DatabaseSettings -from cpl.dependency import ServiceProvider +from cpl.dependency.context import get_provider class MySQLPool: @@ -18,7 +18,11 @@ class MySQLPool: "user": database_settings.user, "password": database_settings.password, "database": database_settings.database, - "ssl_disabled": True, + "charset": database_settings.charset, + "use_unicode": database_settings.use_unicode, + "buffered": database_settings.buffered, + "auth_plugin": database_settings.auth_plugin, + "ssl_disabled": False, } self._pool: Optional[MySQLConnectionPool] = None diff --git a/src/cpl-database/cpl/database/service/migration_service.py b/src/cpl-database/cpl/database/service/migration_service.py index 710480c6..d51b9d2a 100644 --- a/src/cpl-database/cpl/database/service/migration_service.py +++ b/src/cpl-database/cpl/database/service/migration_service.py @@ -7,14 +7,16 @@ from cpl.database.model import Migration from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao +from cpl.dependency.hosted.startup_task import StartupTask -class MigrationService: +class MigrationService(StartupTask): - def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao): + def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao): + StartupTask.__init__(self) self._logger = logger self._db = db - self._executedMigrationDao = executedMigrationDao + self._executed_migration_dao = executed_migration_dao self._script_directories: list[str] = [] @@ -23,12 +25,15 @@ class MigrationService: elif ServerType.server_type == ServerTypes.MYSQL: self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql")) + async def run(self): + await self._execute(self._load_scripts()) + def with_directory(self, directory: str) -> "MigrationService": self._script_directories.append(directory) return self async def _get_migration_history(self) -> list[ExecutedMigration]: - results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}") + results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}") applied_migrations = [] for result in results: applied_migrations.append(ExecutedMigration(result[0])) @@ -91,7 +96,7 @@ class MigrationService: try: # check if table exists if len(result) > 0: - migration_from_db = await self._executedMigrationDao.find_by_id(migration.name) + migration_from_db = await self._executed_migration_dao.find_by_id(migration.name) if migration_from_db is not None: continue @@ -99,12 +104,9 @@ class MigrationService: await self._db.execute(migration.script, multi=True) - await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True) + await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True) except Exception as e: self._logger.fatal( f"Migration failed: {migration.name}\n{active_statement}", e, ) - - async def migrate(self): - await self._execute(self._load_scripts()) diff --git a/src/cpl-dependency/cpl/dependency/context.py b/src/cpl-dependency/cpl/dependency/context.py index d81c299e..1254b982 100644 --- a/src/cpl-dependency/cpl/dependency/context.py +++ b/src/cpl-dependency/cpl/dependency/context.py @@ -1,15 +1,16 @@ import contextvars from contextlib import contextmanager -_current_provider = contextvars.ContextVar("current_provider", default=None) + +_current_provider = contextvars.ContextVar("current_provider") -def use_root_provider(provider): +def use_root_provider(provider: "ServiceProvider"): _current_provider.set(provider) @contextmanager -def use_provider(provider): +def use_provider(provider: "ServiceProvider"): token = _current_provider.set(provider) try: yield @@ -17,5 +18,5 @@ def use_provider(provider): _current_provider.reset(token) -def get_provider(): +def get_provider() -> "ServiceProvider": return _current_provider.get() diff --git a/src/cpl-dependency/cpl/dependency/hosted/__init__.py b/src/cpl-dependency/cpl/dependency/hosted/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-dependency/cpl/dependency/hosted/startup_task.py b/src/cpl-dependency/cpl/dependency/hosted/startup_task.py new file mode 100644 index 00000000..3d16e921 --- /dev/null +++ b/src/cpl-dependency/cpl/dependency/hosted/startup_task.py @@ -0,0 +1,6 @@ +from abc import ABC, abstractmethod + + +class StartupTask(ABC): + @abstractmethod + async def run(self): ... diff --git a/src/cpl-dependency/cpl/dependency/service_collection.py b/src/cpl-dependency/cpl/dependency/service_collection.py index 47c2d1fe..a63b263a 100644 --- a/src/cpl-dependency/cpl/dependency/service_collection.py +++ b/src/cpl-dependency/cpl/dependency/service_collection.py @@ -3,6 +3,7 @@ from typing import Union, Type, Callable, Self from cpl.core.log.logger_abc import LoggerABC from cpl.core.typing import T, Service from cpl.core.utils.cache import Cache +from cpl.dependency.hosted.startup_task import StartupTask from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_lifetime import ServiceLifetimeEnum from cpl.dependency.service_provider import ServiceProvider @@ -61,6 +62,10 @@ class ServiceCollection: self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service) return self + def add_startup_task(self, task: Type[StartupTask]) -> Self: + self.add_singleton(StartupTask, task) + return self + def build(self) -> ServiceProvider: sp = ServiceProvider(self._service_descriptors) return sp