StartupTask #186
This commit is contained in:
@@ -9,7 +9,6 @@ from .table_manager import TableManager
|
||||
|
||||
|
||||
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
||||
from cpl.application.host import Host
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
|
||||
migration_service = self._services.get_service(MigrationService)
|
||||
@@ -21,8 +20,6 @@ def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _Applica
|
||||
for path in paths:
|
||||
migration_service.with_directory(path)
|
||||
|
||||
Host.run(migration_service.migrate)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.const import DATETIME_FORMAT
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||
from cpl.dependency import get_provider
|
||||
@@ -351,13 +352,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}"
|
||||
|
||||
return f"""
|
||||
INSERT INTO {self._table_name} (
|
||||
{fields}
|
||||
) VALUES (
|
||||
{values}
|
||||
)
|
||||
RETURNING {self.__primary_key};
|
||||
"""
|
||||
INSERT INTO {self._table_name} (
|
||||
{fields}
|
||||
) VALUES (
|
||||
{values}
|
||||
)
|
||||
{"RETURNING {self.__primary_key};" if ServerType.server_type == ServerTypes.POSTGRES else ";SELECT LAST_INSERT_ID();"}
|
||||
"""
|
||||
|
||||
async def create(self, obj: T_DBM, skip_editor=False) -> int:
|
||||
self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}")
|
||||
|
||||
@@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class MySQLPool:
|
||||
@@ -18,7 +18,11 @@ class MySQLPool:
|
||||
"user": database_settings.user,
|
||||
"password": database_settings.password,
|
||||
"database": database_settings.database,
|
||||
"ssl_disabled": True,
|
||||
"charset": database_settings.charset,
|
||||
"use_unicode": database_settings.use_unicode,
|
||||
"buffered": database_settings.buffered,
|
||||
"auth_plugin": database_settings.auth_plugin,
|
||||
"ssl_disabled": False,
|
||||
}
|
||||
self._pool: Optional[MySQLConnectionPool] = None
|
||||
|
||||
|
||||
@@ -7,14 +7,16 @@ from cpl.database.model import Migration
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||
from cpl.dependency.hosted.startup_task import StartupTask
|
||||
|
||||
|
||||
class MigrationService:
|
||||
class MigrationService(StartupTask):
|
||||
|
||||
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
||||
def __init__(self, logger: DBLogger, db: DBContextABC, executed_migration_dao: ExecutedMigrationDao):
|
||||
StartupTask.__init__(self)
|
||||
self._logger = logger
|
||||
self._db = db
|
||||
self._executedMigrationDao = executedMigrationDao
|
||||
self._executed_migration_dao = executed_migration_dao
|
||||
|
||||
self._script_directories: list[str] = []
|
||||
|
||||
@@ -23,12 +25,15 @@ class MigrationService:
|
||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
|
||||
|
||||
async def run(self):
|
||||
await self._execute(self._load_scripts())
|
||||
|
||||
def with_directory(self, directory: str) -> "MigrationService":
|
||||
self._script_directories.append(directory)
|
||||
return self
|
||||
|
||||
async def _get_migration_history(self) -> list[ExecutedMigration]:
|
||||
results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}")
|
||||
results = await self._db.select(f"SELECT * FROM {self._executed_migration_dao.table_name}")
|
||||
applied_migrations = []
|
||||
for result in results:
|
||||
applied_migrations.append(ExecutedMigration(result[0]))
|
||||
@@ -91,7 +96,7 @@ class MigrationService:
|
||||
try:
|
||||
# check if table exists
|
||||
if len(result) > 0:
|
||||
migration_from_db = await self._executedMigrationDao.find_by_id(migration.name)
|
||||
migration_from_db = await self._executed_migration_dao.find_by_id(migration.name)
|
||||
if migration_from_db is not None:
|
||||
continue
|
||||
|
||||
@@ -99,12 +104,9 @@ class MigrationService:
|
||||
|
||||
await self._db.execute(migration.script, multi=True)
|
||||
|
||||
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
|
||||
await self._executed_migration_dao.create(ExecutedMigration(migration.name), skip_editor=True)
|
||||
except Exception as e:
|
||||
self._logger.fatal(
|
||||
f"Migration failed: {migration.name}\n{active_statement}",
|
||||
e,
|
||||
)
|
||||
|
||||
async def migrate(self):
|
||||
await self._execute(self._load_scripts())
|
||||
|
||||
Reference in New Issue
Block a user