StartupTask #186
This commit is contained in:
@@ -9,7 +9,7 @@ from cpl.core.configuration import Configuration
|
|||||||
from cpl.core.console import Console
|
from cpl.core.console import Console
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
from cpl.core.utils.cache import Cache
|
from cpl.core.utils.cache import Cache
|
||||||
from custom.api.src.scoped_service import ScopedService
|
from scoped_service import ScopedService
|
||||||
from service import PingService
|
from service import PingService
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from cpl.api import APILogger
|
|||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.core.console import Console
|
from cpl.core.console import Console
|
||||||
from cpl.dependency import ServiceProvider
|
from cpl.dependency import ServiceProvider
|
||||||
from custom.api.src.scoped_service import ScopedService
|
from scoped_service import ScopedService
|
||||||
|
|
||||||
|
|
||||||
@Router.authenticate()
|
@Router.authenticate()
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
from cpl.dependency import get_provider
|
||||||
|
from cpl.dependency.hosted.startup_task import StartupTask
|
||||||
|
|
||||||
|
|
||||||
class Host:
|
class Host:
|
||||||
_loop = asyncio.get_event_loop()
|
_loop = asyncio.get_event_loop()
|
||||||
@@ -9,8 +12,20 @@ class Host:
|
|||||||
def get_loop(cls):
|
def get_loop(cls):
|
||||||
return cls._loop
|
return cls._loop
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_start_tasks(cls):
|
||||||
|
provider = get_provider()
|
||||||
|
tasks = provider.get_services(StartupTask)
|
||||||
|
for task in tasks:
|
||||||
|
if asyncio.iscoroutinefunction(task.run):
|
||||||
|
cls._loop.run_until_complete(task.run())
|
||||||
|
else:
|
||||||
|
task.run()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run(cls, func: Callable, *args, **kwargs):
|
def run(cls, func: Callable, *args, **kwargs):
|
||||||
|
cls.run_start_tasks()
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(func):
|
if asyncio.iscoroutinefunction(func):
|
||||||
return cls._loop.run_until_complete(func(*args, **kwargs))
|
return cls._loop.run_until_complete(func(*args, **kwargs))
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from .table_manager import TableManager
|
|||||||
|
|
||||||
|
|
||||||
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
||||||
from cpl.application.host import Host
|
|
||||||
from cpl.database.service.migration_service import MigrationService
|
from cpl.database.service.migration_service import MigrationService
|
||||||
|
|
||||||
migration_service = self._services.get_service(MigrationService)
|
migration_service = self._services.get_service(MigrationService)
|
||||||
@@ -21,8 +20,6 @@ def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _Applica
|
|||||||
for path in paths:
|
for path in paths:
|
||||||
migration_service.with_directory(path)
|
migration_service.with_directory(path)
|
||||||
|
|
||||||
Host.run(migration_service.migrate)
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from cpl.database.abc.db_context_abc import DBContextABC
|
|||||||
from cpl.database.const import DATETIME_FORMAT
|
from cpl.database.const import DATETIME_FORMAT
|
||||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||||
from cpl.database.logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
|
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||||
from cpl.dependency import get_provider
|
from cpl.dependency import get_provider
|
||||||
@@ -356,7 +357,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
|||||||
) VALUES (
|
) VALUES (
|
||||||
{values}
|
{values}
|
||||||
)
|
)
|
||||||
RETURNING {self.__primary_key};
|
{"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def create(self, obj: T_DBM, skip_editor=False) -> int:
|
async def create(self, obj: T_DBM, skip_editor=False) -> int:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool
|
|||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
from cpl.database.logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.model import DatabaseSettings
|
from cpl.database.model import DatabaseSettings
|
||||||
from cpl.dependency import ServiceProvider
|
from cpl.dependency.context import get_provider
|
||||||
|
|
||||||
|
|
||||||
class MySQLPool:
|
class MySQLPool:
|
||||||
@@ -18,7 +18,11 @@ class MySQLPool:
|
|||||||
"user": database_settings.user,
|
"user": database_settings.user,
|
||||||
"password": database_settings.password,
|
"password": database_settings.password,
|
||||||
"database": database_settings.database,
|
"database": database_settings.database,
|
||||||
"ssl_disabled": True,
|
"charset": database_settings.charset,
|
||||||
|
"use_unicode": database_settings.use_unicode,
|
||||||
|
"buffered": database_settings.buffered,
|
||||||
|
"auth_plugin": database_settings.auth_plugin,
|
||||||
|
"ssl_disabled": False,
|
||||||
}
|
}
|
||||||
self._pool: Optional[MySQLConnectionPool] = None
|
self._pool: Optional[MySQLConnectionPool] = None
|
||||||
|
|
||||||
|
|||||||
@@ -7,14 +7,16 @@ from cpl.database.model import Migration
|
|||||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||||
|
from cpl.dependency.hosted.startup_task import StartupTask
|
||||||
|
|
||||||
|
|
||||||
class MigrationService:
|
class MigrationService(StartupTask):
|
||||||
|
|
||||||
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao):
|
||||||
|
StartupTask.__init__(self)
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._db = db
|
self._db = db
|
||||||
self._executedMigrationDao = executedMigrationDao
|
self._executed_migration_dao = executed_migration_dao
|
||||||
|
|
||||||
self._script_directories: list[str] = []
|
self._script_directories: list[str] = []
|
||||||
|
|
||||||
@@ -23,12 +25,15 @@ class MigrationService:
|
|||||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||||
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
|
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
await self._execute(self._load_scripts())
|
||||||
|
|
||||||
def with_directory(self, directory: str) -> "MigrationService":
|
def with_directory(self, directory: str) -> "MigrationService":
|
||||||
self._script_directories.append(directory)
|
self._script_directories.append(directory)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def _get_migration_history(self) -> list[ExecutedMigration]:
|
async def _get_migration_history(self) -> list[ExecutedMigration]:
|
||||||
results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}")
|
results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}")
|
||||||
applied_migrations = []
|
applied_migrations = []
|
||||||
for result in results:
|
for result in results:
|
||||||
applied_migrations.append(ExecutedMigration(result[0]))
|
applied_migrations.append(ExecutedMigration(result[0]))
|
||||||
@@ -91,7 +96,7 @@ class MigrationService:
|
|||||||
try:
|
try:
|
||||||
# check if table exists
|
# check if table exists
|
||||||
if len(result) > 0:
|
if len(result) > 0:
|
||||||
migration_from_db = await self._executedMigrationDao.find_by_id(migration.name)
|
migration_from_db = await self._executed_migration_dao.find_by_id(migration.name)
|
||||||
if migration_from_db is not None:
|
if migration_from_db is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -99,12 +104,9 @@ class MigrationService:
|
|||||||
|
|
||||||
await self._db.execute(migration.script, multi=True)
|
await self._db.execute(migration.script, multi=True)
|
||||||
|
|
||||||
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
|
await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.fatal(
|
self._logger.fatal(
|
||||||
f"Migration failed: {migration.name}\n{active_statement}",
|
f"Migration failed: {migration.name}\n{active_statement}",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def migrate(self):
|
|
||||||
await self._execute(self._load_scripts())
|
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
import contextvars
|
import contextvars
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
_current_provider = contextvars.ContextVar("current_provider", default=None)
|
|
||||||
|
_current_provider = contextvars.ContextVar("current_provider")
|
||||||
|
|
||||||
|
|
||||||
def use_root_provider(provider):
|
def use_root_provider(provider: "ServiceProvider"):
|
||||||
_current_provider.set(provider)
|
_current_provider.set(provider)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_provider(provider):
|
def use_provider(provider: "ServiceProvider"):
|
||||||
token = _current_provider.set(provider)
|
token = _current_provider.set(provider)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
@@ -17,5 +18,5 @@ def use_provider(provider):
|
|||||||
_current_provider.reset(token)
|
_current_provider.reset(token)
|
||||||
|
|
||||||
|
|
||||||
def get_provider():
|
def get_provider() -> "ServiceProvider":
|
||||||
return _current_provider.get()
|
return _current_provider.get()
|
||||||
|
|||||||
6
src/cpl-dependency/cpl/dependency/hosted/startup_task.py
Normal file
6
src/cpl-dependency/cpl/dependency/hosted/startup_task.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class StartupTask(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self): ...
|
||||||
@@ -3,6 +3,7 @@ from typing import Union, Type, Callable, Self
|
|||||||
from cpl.core.log.logger_abc import LoggerABC
|
from cpl.core.log.logger_abc import LoggerABC
|
||||||
from cpl.core.typing import T, Service
|
from cpl.core.typing import T, Service
|
||||||
from cpl.core.utils.cache import Cache
|
from cpl.core.utils.cache import Cache
|
||||||
|
from cpl.dependency.hosted.startup_task import StartupTask
|
||||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
@@ -61,6 +62,10 @@ class ServiceCollection:
|
|||||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def add_startup_task(self, task: Type[StartupTask]) -> Self:
|
||||||
|
self.add_singleton(StartupTask, task)
|
||||||
|
return self
|
||||||
|
|
||||||
def build(self) -> ServiceProvider:
|
def build(self) -> ServiceProvider:
|
||||||
sp = ServiceProvider(self._service_descriptors)
|
sp = ServiceProvider(self._service_descriptors)
|
||||||
return sp
|
return sp
|
||||||
|
|||||||
Reference in New Issue
Block a user