WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
11 changed files with 58 additions and 27 deletions
Showing only changes of commit 15d3c59f02 - Show all commits

View File

@@ -9,7 +9,7 @@ from cpl.core.configuration import Configuration
from cpl.core.console import Console from cpl.core.console import Console
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache from cpl.core.utils.cache import Cache
from custom.api.src.scoped_service import ScopedService from scoped_service import ScopedService
from service import PingService from service import PingService

View File

@@ -7,7 +7,7 @@ from cpl.api import APILogger
from cpl.api.router import Router from cpl.api.router import Router
from cpl.core.console import Console from cpl.core.console import Console
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider
from custom.api.src.scoped_service import ScopedService from scoped_service import ScopedService
@Router.authenticate() @Router.authenticate()

View File

@@ -1,6 +1,9 @@
import asyncio import asyncio
from typing import Callable from typing import Callable
from cpl.dependency import get_provider
from cpl.dependency.hosted.startup_task import StartupTask
class Host: class Host:
_loop = asyncio.get_event_loop() _loop = asyncio.get_event_loop()
@@ -9,8 +12,20 @@ class Host:
def get_loop(cls): def get_loop(cls):
return cls._loop return cls._loop
@classmethod
def run_start_tasks(cls):
provider = get_provider()
tasks = provider.get_services(StartupTask)
for task in tasks:
if asyncio.iscoroutinefunction(task.run):
cls._loop.run_until_complete(task.run())
else:
task.run()
@classmethod @classmethod
def run(cls, func: Callable, *args, **kwargs): def run(cls, func: Callable, *args, **kwargs):
cls.run_start_tasks()
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
return cls._loop.run_until_complete(func(*args, **kwargs)) return cls._loop.run_until_complete(func(*args, **kwargs))

View File

@@ -9,7 +9,6 @@ from .table_manager import TableManager
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC: def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService) migration_service = self._services.get_service(MigrationService)
@@ -21,8 +20,6 @@ def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _Applica
for path in paths: for path in paths:
migration_service.with_directory(path) migration_service.with_directory(path)
Host.run(migration_service.migrate)
return self return self

View File

@@ -11,6 +11,7 @@ from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT from cpl.database.const import DATETIME_FORMAT
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency import get_provider from cpl.dependency import get_provider
@@ -351,13 +352,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}" values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}"
return f""" return f"""
INSERT INTO {self._table_name} ( INSERT INTO {self._table_name} (
{fields} {fields}
) VALUES ( ) VALUES (
{values} {values}
) )
RETURNING {self.__primary_key}; {"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"}
""" """
async def create(self, obj: T_DBM, skip_editor=False) -> int: async def create(self, obj: T_DBM, skip_editor=False) -> int:
self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}") self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}")

View File

@@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProvider from cpl.dependency.context import get_provider
class MySQLPool: class MySQLPool:
@@ -18,7 +18,11 @@ class MySQLPool:
"user": database_settings.user, "user": database_settings.user,
"password": database_settings.password, "password": database_settings.password,
"database": database_settings.database, "database": database_settings.database,
"ssl_disabled": True, "charset": database_settings.charset,
"use_unicode": database_settings.use_unicode,
"buffered": database_settings.buffered,
"auth_plugin": database_settings.auth_plugin,
"ssl_disabled": False,
} }
self._pool: Optional[MySQLConnectionPool] = None self._pool: Optional[MySQLConnectionPool] = None

View File

@@ -7,14 +7,16 @@ from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
from cpl.dependency.hosted.startup_task import StartupTask
class MigrationService: class MigrationService(StartupTask):
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao): def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao):
StartupTask.__init__(self)
self._logger = logger self._logger = logger
self._db = db self._db = db
self._executedMigrationDao = executedMigrationDao self._executed_migration_dao = executed_migration_dao
self._script_directories: list[str] = [] self._script_directories: list[str] = []
@@ -23,12 +25,15 @@ class MigrationService:
elif ServerType.server_type == ServerTypes.MYSQL: elif ServerType.server_type == ServerTypes.MYSQL:
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql")) self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
async def run(self):
await self._execute(self._load_scripts())
def with_directory(self, directory: str) -> "MigrationService": def with_directory(self, directory: str) -> "MigrationService":
self._script_directories.append(directory) self._script_directories.append(directory)
return self return self
async def _get_migration_history(self) -> list[ExecutedMigration]: async def _get_migration_history(self) -> list[ExecutedMigration]:
results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}") results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}")
applied_migrations = [] applied_migrations = []
for result in results: for result in results:
applied_migrations.append(ExecutedMigration(result[0])) applied_migrations.append(ExecutedMigration(result[0]))
@@ -91,7 +96,7 @@ class MigrationService:
try: try:
# check if table exists # check if table exists
if len(result) > 0: if len(result) > 0:
migration_from_db = await self._executedMigrationDao.find_by_id(migration.name) migration_from_db = await self._executed_migration_dao.find_by_id(migration.name)
if migration_from_db is not None: if migration_from_db is not None:
continue continue
@@ -99,12 +104,9 @@ class MigrationService:
await self._db.execute(migration.script, multi=True) await self._db.execute(migration.script, multi=True)
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True) await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e: except Exception as e:
self._logger.fatal( self._logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}", f"Migration failed: {migration.name}\n{active_statement}",
e, e,
) )
async def migrate(self):
await self._execute(self._load_scripts())

View File

@@ -1,15 +1,16 @@
import contextvars import contextvars
from contextlib import contextmanager from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
_current_provider = contextvars.ContextVar("current_provider")
def use_root_provider(provider): def use_root_provider(provider: "ServiceProvider"):
_current_provider.set(provider) _current_provider.set(provider)
@contextmanager @contextmanager
def use_provider(provider): def use_provider(provider: "ServiceProvider"):
token = _current_provider.set(provider) token = _current_provider.set(provider)
try: try:
yield yield
@@ -17,5 +18,5 @@ def use_provider(provider):
_current_provider.reset(token) _current_provider.reset(token)
def get_provider(): def get_provider() -> "ServiceProvider":
return _current_provider.get() return _current_provider.get()

View File

@@ -0,0 +1,6 @@
from abc import ABC, abstractmethod
class StartupTask(ABC):
@abstractmethod
async def run(self): ...

View File

@@ -3,6 +3,7 @@ from typing import Union, Type, Callable, Self
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache from cpl.core.utils.cache import Cache
from cpl.dependency.hosted.startup_task import StartupTask
from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime import ServiceLifetimeEnum from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
@@ -61,6 +62,10 @@ class ServiceCollection:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service) self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self return self
def add_startup_task(self, task: Type[StartupTask]) -> Self:
self.add_singleton(StartupTask, task)
return self
def build(self) -> ServiceProvider: def build(self) -> ServiceProvider:
sp = ServiceProvider(self._service_descriptors) sp = ServiceProvider(self._service_descriptors)
return sp return sp