From 4625b626e6efc29fdc9c7ed01f147cb99418004a Mon Sep 17 00:00:00 2001 From: edraft Date: Tue, 16 Sep 2025 22:19:59 +0200 Subject: [PATCH] Added dao base --- README.md | 153 --- src/cpl-core/cpl/core/typing.py | 6 + src/cpl-core/cpl/core/utils/__init__.py | 2 +- .../cpl/core/utils/{b64.py => base64.py} | 2 +- src/cpl-database/cpl/database/__init__.py | 42 +- .../_external_data_temp_table_builder.py | 68 ++ src/cpl-database/cpl/database/abc/__init__.py | 5 + .../connection_abc.py} | 4 +- .../database/abc/data_access_object_abc.py | 875 ++++++++++++++++++ .../cpl/database/abc/db_context_abc.py | 53 ++ .../cpl/database/abc/db_join_model_abc.py | 30 + .../cpl/database/abc/db_model_abc.py | 79 ++ .../cpl/database/abc/db_model_dao_abc.py | 25 + .../cpl/database/{ => abc}/table_abc.py | 0 src/cpl-database/cpl/database/const.py | 1 + .../database/database_settings_name_enum.py | 13 - .../cpl/database/internal_tables.py | 15 + .../cpl/database/model/__init__.py | 3 + .../database/{ => model}/database_settings.py | 25 +- .../cpl/database/model/migration.py | 12 + .../cpl/database/model/server_type.py | 21 + .../database_connection.py => connection.py} | 6 +- .../cpl/database/mysql/connection/__init__.py | 2 - .../cpl/database/mysql/context/__init__.py | 2 - .../mysql/context/database_context.py | 52 -- .../mysql/context/database_context_abc.py | 40 - .../cpl/database/mysql/db_context.py | 84 ++ .../cpl/database/mysql/mysql_pool.py | 105 +++ .../cpl/database/postgres/__init__.py | 0 .../cpl/database/postgres/db_context.py | 86 ++ .../cpl/database/postgres/postgres_pool.py | 123 +++ .../database/postgres/sql_select_builder.py | 154 +++ .../cpl/database/schema/__init__.py | 0 .../cpl/database/schema/executed_migration.py | 18 + .../database/schema/executed_migration_dao.py | 14 + .../database/scripts/mysql/0-cpl-initial.sql | 6 + .../cpl/database/scripts/mysql/trigger.txt | 26 + .../scripts/postgres/0-cpl-initial.sql | 47 + .../cpl/database/service/__init__.py | 0 .../cpl/database/service/migration_service.py | 111 +++ src/cpl-database/cpl/database/typing.py | 65 ++ src/cpl-database/requirements.txt | 8 +- .../cpl/dependency/service_provider_abc.py | 4 + tests/custom/database/src/application.py | 28 +- .../src/appsettings.edrafts-lapi.json | 4 +- .../database/src/appsettings.edrafts-pc.json | 7 +- tests/custom/database/src/main.py | 12 +- tests/custom/database/src/model/city_model.py | 2 +- tests/custom/database/src/model/user.py | 16 + tests/custom/database/src/model/user_dao.py | 14 + tests/custom/database/src/model/user_model.py | 3 +- tests/custom/database/src/model/user_repo.py | 27 +- .../custom/database/src/scripts/0-initial.sql | 14 + tests/custom/database/src/startup.py | 25 +- 54 files changed, 2199 insertions(+), 340 deletions(-) rename src/cpl-core/cpl/core/utils/{b64.py => base64.py} (98%) create mode 100644 src/cpl-database/cpl/database/_external_data_temp_table_builder.py create mode 100644 src/cpl-database/cpl/database/abc/__init__.py rename src/cpl-database/cpl/database/{mysql/connection/database_connection_abc.py => abc/connection_abc.py} (89%) create mode 100644 src/cpl-database/cpl/database/abc/data_access_object_abc.py create mode 100644 src/cpl-database/cpl/database/abc/db_context_abc.py create mode 100644 src/cpl-database/cpl/database/abc/db_join_model_abc.py create mode 100644 src/cpl-database/cpl/database/abc/db_model_abc.py create mode 100644 src/cpl-database/cpl/database/abc/db_model_dao_abc.py rename src/cpl-database/cpl/database/{ => abc}/table_abc.py (100%) create mode 100644 src/cpl-database/cpl/database/const.py delete mode 100644 src/cpl-database/cpl/database/database_settings_name_enum.py create mode 100644 src/cpl-database/cpl/database/internal_tables.py create mode 100644 src/cpl-database/cpl/database/model/__init__.py rename src/cpl-database/cpl/database/{ => model}/database_settings.py (62%) create mode 100644 src/cpl-database/cpl/database/model/migration.py create mode 100644 src/cpl-database/cpl/database/model/server_type.py rename src/cpl-database/cpl/database/mysql/{connection/database_connection.py => connection.py} (90%) delete mode 100644 src/cpl-database/cpl/database/mysql/connection/__init__.py delete mode 100644 src/cpl-database/cpl/database/mysql/context/__init__.py delete mode 100644 src/cpl-database/cpl/database/mysql/context/database_context.py delete mode 100644 src/cpl-database/cpl/database/mysql/context/database_context_abc.py create mode 100644 src/cpl-database/cpl/database/mysql/db_context.py create mode 100644 src/cpl-database/cpl/database/mysql/mysql_pool.py create mode 100644 src/cpl-database/cpl/database/postgres/__init__.py create mode 100644 src/cpl-database/cpl/database/postgres/db_context.py create mode 100644 src/cpl-database/cpl/database/postgres/postgres_pool.py create mode 100644 src/cpl-database/cpl/database/postgres/sql_select_builder.py create mode 100644 src/cpl-database/cpl/database/schema/__init__.py create mode 100644 src/cpl-database/cpl/database/schema/executed_migration.py create mode 100644 src/cpl-database/cpl/database/schema/executed_migration_dao.py create mode 100644 src/cpl-database/cpl/database/scripts/mysql/0-cpl-initial.sql create mode 100644 src/cpl-database/cpl/database/scripts/mysql/trigger.txt create mode 100644 src/cpl-database/cpl/database/scripts/postgres/0-cpl-initial.sql create mode 100644 src/cpl-database/cpl/database/service/__init__.py create mode 100644 src/cpl-database/cpl/database/service/migration_service.py create mode 100644 src/cpl-database/cpl/database/typing.py create mode 100644 tests/custom/database/src/model/user.py create mode 100644 tests/custom/database/src/model/user_dao.py create mode 100644 tests/custom/database/src/scripts/0-initial.sql diff --git a/README.md b/README.md index 2b1a120d..e69de29b 100644 --- a/README.md +++ b/README.md @@ -1,153 +0,0 @@ -

CPL - Common python library

- - -

- -
- - CPL is a development platform for python server applications -
using Python.
-
-

- -## Table of Contents - -
    -
  1. Features
  2. -
  3. - Getting Started - -
  4. -
  5. Roadmap
  6. -
  7. Contributing
  8. -
  9. License
  10. -
  11. Contact
  12. -
- -## Features - -- Expandle -- Application base - - Standardized application classes - - Application object builder - - Application extension classes - - Startup classes - - Startup extension classes -- Configuration - - Configure via object mapped JSON - - Console argument handling -- Console class for in and output - - Banner - - Spinner - - Options (menu) - - Table - - Write - - Write_at - - Write_line - - Write_line_at -- Dependency injection - - Service lifetimes: singleton, scoped and transient -- Providing of application environment - - Environment (development, staging, testing, production) - - Appname - - Customer - - Hostname - - Runtime directory - - Working directory -- Logging - - Standardized logger - - Log-level (FATAL, ERROR, WARN, INFO, DEBUG & TRACE) -- Mail handling - - Send mails -- Pipe classes - - Convert input -- Utils - - Credential manager - - Encryption via BASE64 - - PIP wrapper class based on subprocess - - Run pip commands - - String converter to different variants - - to_lower_case - - to_camel_case - - ... - - -## Getting Started - -[Get started with CPL][quickstart]. - -### Prerequisites - -- Install [python] which includes [Pip installs packages][pip] - -### Installation - -Install the CPL package -```sh -pip install cpl-core --extra-index-url https://pip.sh-edraft.de -``` - -Install the CPL CLI -```sh -pip install cpl-cli --extra-index-url https://pip.sh-edraft.de -``` - -Create workspace: -```sh -cpl new -``` - -Run the application: -```sh -cd -cpl start -``` - - - -## Roadmap - -See the [open issues](https://git.sh-edraft.de/sh-edraft.de/sh_cpl/issues) for a list of proposed features (and known issues). - - - - -## Contributing - -### Contributing Guidelines - -Read through our [contributing guidelines][contributing] to learn about our submission process, coding rules and more. - -### Want to Help? - -Want to file a bug, contribute some code, or improve documentation? Excellent! Read up on our guidelines for [contributing][contributing]. - - - - -## License - -Distributed under the MIT License. See [LICENSE] for more information. - - - - -## Contact - -Sven Heidemann - sven.heidemann@sh-edraft.de - -Project link: [https://git.sh-edraft.de/sh-edraft.de/sh_common_py_lib](https://git.sh-edraft.de/sh-edraft.de/sh_cpl) - - -[pip_url]: https://pip.sh-edraft.de -[python]: https://www.python.org/ -[pip]: https://pypi.org/project/pip/ - - -[project]: https://git.sh-edraft.de/sh-edraft.de/sh_cpl -[quickstart]: https://git.sh-edraft.de/sh-edraft.de/sh_cpl/wiki/quickstart -[contributing]: https://git.sh-edraft.de/sh-edraft.de/sh_cpl/wiki/contributing -[license]: LICENSE diff --git a/src/cpl-core/cpl/core/typing.py b/src/cpl-core/cpl/core/typing.py index 63a1f8f2..b0a980a3 100644 --- a/src/cpl-core/cpl/core/typing.py +++ b/src/cpl-core/cpl/core/typing.py @@ -1,4 +1,5 @@ from typing import TypeVar, Any +from uuid import UUID T = TypeVar("T") D = TypeVar("D") @@ -8,3 +9,8 @@ Service = TypeVar("Service") Source = TypeVar("Source") Messages = list[Any] | Any + +UuidId = str | UUID +SerialId = int + +Id = UuidId | SerialId diff --git a/src/cpl-core/cpl/core/utils/__init__.py b/src/cpl-core/cpl/core/utils/__init__.py index 664d3a02..84bfba28 100644 --- a/src/cpl-core/cpl/core/utils/__init__.py +++ b/src/cpl-core/cpl/core/utils/__init__.py @@ -1,4 +1,4 @@ -from .b64 import B64 +from .base64 import Base64 from .credential_manager import CredentialManager from .json_processor import JSONProcessor from .pip import Pip diff --git a/src/cpl-core/cpl/core/utils/b64.py b/src/cpl-core/cpl/core/utils/base64.py similarity index 98% rename from src/cpl-core/cpl/core/utils/b64.py rename to src/cpl-core/cpl/core/utils/base64.py index 0292e6f6..fe47cbc7 100644 --- a/src/cpl-core/cpl/core/utils/b64.py +++ b/src/cpl-core/cpl/core/utils/base64.py @@ -2,7 +2,7 @@ import base64 from typing import Union -class B64: +class Base64: @staticmethod def encode(string: str) -> str: diff --git a/src/cpl-database/cpl/database/__init__.py b/src/cpl-database/cpl/database/__init__.py index 1d0e9173..a982d815 100644 --- a/src/cpl-database/cpl/database/__init__.py +++ b/src/cpl-database/cpl/database/__init__.py @@ -1,23 +1,41 @@ +from typing import Type + from cpl.dependency import ServiceCollection as _ServiceCollection -from . import mysql -from .database_settings import DatabaseSettings -from .database_settings_name_enum import DatabaseSettingsNameEnum -from .mysql.context import DatabaseContextABC, DatabaseContext -from .table_abc import TableABC +from . import mysql as _mysql +from . import postgres as _postgres +from .internal_tables import InternalTables -def add_mysql(collection: _ServiceCollection): +def _add(collection: _ServiceCollection,db_context: Type, default_port: int, server_type: str): from cpl.core.console import Console from cpl.core.configuration import Configuration + from cpl.database.abc.db_context_abc import DBContextABC + from cpl.database.model.server_type import ServerTypes, ServerType + from cpl.database.model.database_settings import DatabaseSettings + from cpl.database.service.migration_service import MigrationService + from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao try: - collection.add_singleton(DatabaseContextABC, DatabaseContext) - database_context = collection.build_service_provider().get_service(DatabaseContextABC) + ServerType.set_server_type(ServerTypes(server_type)) + Configuration.set("DB_DEFAULT_PORT", default_port) - db_settings: DatabaseSettings = Configuration.get(DatabaseSettings) - database_context.connect(db_settings) + collection.add_singleton(DBContextABC, db_context) + collection.add_singleton(ExecutedMigrationDao) + collection.add_singleton(MigrationService) except ImportError as e: - Console.error("cpl-translation is not installed", str(e)) + Console.error("cpl-database is not installed", str(e)) + +def add_mysql(collection: _ServiceCollection): + from cpl.database.mysql.db_context import DBContext + from cpl.database.model import ServerTypes + _add(collection, DBContext, 3306, ServerTypes.MYSQL.value) -_ServiceCollection.with_module(add_mysql, mysql.__name__) +def add_postgres(collection: _ServiceCollection): + from cpl.database.mysql.db_context import DBContext + from cpl.database.model import ServerTypes + _add(collection, DBContext, 5432, ServerTypes.POSTGRES.value) + + +_ServiceCollection.with_module(add_mysql, _mysql.__name__) +_ServiceCollection.with_module(add_postgres, _postgres.__name__) diff --git a/src/cpl-database/cpl/database/_external_data_temp_table_builder.py b/src/cpl-database/cpl/database/_external_data_temp_table_builder.py new file mode 100644 index 00000000..588630b4 --- /dev/null +++ b/src/cpl-database/cpl/database/_external_data_temp_table_builder.py @@ -0,0 +1,68 @@ +import textwrap +from typing import Callable + + +class ExternalDataTempTableBuilder: + + def __init__(self): + self._table_name = None + self._fields: dict[str, str] = {} + self._primary_key = "id" + self._join_ref_table = None + self._value_getter = None + + @property + def table_name(self) -> str: + return self._table_name + + @property + def fields(self) -> dict[str, str]: + return self._fields + + @property + def primary_key(self) -> str: + return self._primary_key + + @property + def join_ref_table(self) -> str: + return self._join_ref_table + + def with_table_name(self, table_name: str) -> "ExternalDataTempTableBuilder": + self._join_ref_table = table_name + + if "." in table_name: + table_name = table_name.split(".")[-1] + + if not table_name.endswith("_temp"): + table_name = f"{table_name}_temp" + + self._table_name = table_name + return self + + def with_field(self, name: str, sql_type: str, primary=False) -> "ExternalDataTempTableBuilder": + if primary: + sql_type += " PRIMARY KEY" + self._primary_key = name + self._fields[name] = sql_type + return self + + def with_value_getter(self, value_getter: Callable) -> "ExternalDataTempTableBuilder": + self._value_getter = value_getter + return self + + async def build(self) -> str: + assert self._table_name is not None, "Table name is required" + assert self._value_getter is not None, "Value getter is required" + + values_str = ", ".join([f"{value}" for value in await self._value_getter()]) + + return textwrap.dedent( + f""" + DROP TABLE IF EXISTS {self._table_name}; + CREATE TEMP TABLE {self._table_name} ( + {", ".join([f"{k} {v}" for k, v in self._fields.items()])} + ); + + INSERT INTO {self._table_name} VALUES {values_str}; + """ + ) diff --git a/src/cpl-database/cpl/database/abc/__init__.py b/src/cpl-database/cpl/database/abc/__init__.py new file mode 100644 index 00000000..05508900 --- /dev/null +++ b/src/cpl-database/cpl/database/abc/__init__.py @@ -0,0 +1,5 @@ +from .connection_abc import ConnectionABC +from .db_context_abc import DBContextABC +from .db_join_model_abc import DbJoinModelABC +from .db_model_abc import DbModelABC +from .db_model_dao_abc import DbModelDaoABC diff --git a/src/cpl-database/cpl/database/mysql/connection/database_connection_abc.py b/src/cpl-database/cpl/database/abc/connection_abc.py similarity index 89% rename from src/cpl-database/cpl/database/mysql/connection/database_connection_abc.py rename to src/cpl-database/cpl/database/abc/connection_abc.py index 6980eeea..a0fab905 100644 --- a/src/cpl-database/cpl/database/mysql/connection/database_connection_abc.py +++ b/src/cpl-database/cpl/database/abc/connection_abc.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from cpl.database.database_settings import DatabaseSettings +from cpl.database.model.database_settings import DatabaseSettings from mysql.connector.abstracts import MySQLConnectionAbstract from mysql.connector.cursor import MySQLCursorBuffered -class DatabaseConnectionABC(ABC): +class ConnectionABC(ABC): r"""ABC for the :class:`cpl.database.connection.database_connection.DatabaseConnection`""" @abstractmethod diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py new file mode 100644 index 00000000..ed7ebad0 --- /dev/null +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -0,0 +1,875 @@ +import datetime +from abc import ABC, abstractmethod +from enum import Enum +from types import NoneType +from typing import Generic, Optional, Union, Type, List, Any + +from cpl.core.typing import T, Id +from cpl.core.utils import String +from cpl.core.utils.get_value import get_value +from cpl.database._external_data_temp_table_builder import ExternalDataTempTableBuilder +from cpl.database.abc.db_context_abc import DBContextABC +from cpl.database.const import DATETIME_FORMAT +from cpl.database.db_logger import DBLogger +from cpl.database.postgres.sql_select_builder import SQLSelectBuilder +from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts + + +class DataAccessObjectABC(ABC, Generic[T_DBM]): + + @abstractmethod + def __init__(self, source: str, model_type: Type[T_DBM], table_name: str): + from cpl.dependency.service_provider_abc import ServiceProviderABC + + self._db = ServiceProviderABC.get_global_provider().get_service(DBContextABC) + + self._logger = DBLogger(source) + self._model_type = model_type + self._table_name = table_name + + self._logger = DBLogger(source) + self._model_type = model_type + self._table_name = table_name + + self._default_filter_condition = None + + self.__attributes: dict[str, type] = {} + + self.__db_names: dict[str, str] = {} + self.__foreign_tables: dict[str, tuple[str, str]] = {} + self.__foreign_table_keys: dict[str, str] = {} + self.__foreign_dao: dict[str, "DataAccessObjectABC"] = {} + + self.__date_attributes: set[str] = set() + self.__ignored_attributes: set[str] = set() + + self.__primary_key = "id" + self.__primary_key_type = int + self._external_fields: dict[str, ExternalDataTempTableBuilder] = {} + + @property + def table_name(self) -> str: + return self._table_name + + def has_attribute(self, attr_name: Attribute) -> bool: + """ + Check if the attribute exists in the DAO + :param Attribute attr_name: Name of the attribute + :return: True if the attribute exists, False otherwise + """ + return attr_name in self.__attributes + + def attribute( + self, + attr_name: Attribute, + attr_type: type, + db_name: str = None, + ignore=False, + primary_key=False, + aliases: list[str] = None, + ): + """ + Add an attribute for db and object mapping to the data access object + :param Attribute attr_name: Name of the attribute in the object + :param type attr_type: Python type of the attribute to cast db value to + :param str db_name: Name of the field in the database, if None the attribute lowered attr_name without "_" is used + :param bool ignore: Defines if field is ignored for create and update (for e.g. auto increment fields or created/updated fields) + :param bool primary_key: Defines if field is the primary key + :param list[str] aliases: List of aliases for the attribute name + :return: + """ + if isinstance(attr_name, property): + attr_name = attr_name.fget.__name__ + + self.__attributes[attr_name] = attr_type + if ignore: + self.__ignored_attributes.add(attr_name) + + if not db_name: + db_name = attr_name.lower().replace("_", "") + + self.__db_names[attr_name] = db_name + self.__db_names[db_name] = db_name + + if aliases is not None: + for alias in aliases: + if alias in self.__db_names: + raise ValueError(f"Alias {alias} already exists") + self.__db_names[alias] = db_name + + if primary_key: + self.__primary_key = db_name + self.__primary_key_type = attr_type + + if attr_type in [datetime, datetime.datetime]: + self.__date_attributes.add(attr_name) + self.__date_attributes.add(db_name) + + def reference( + self, + attr: Attribute, + primary_attr: Attribute, + foreign_attr: Attribute, + table_name: str, + reference_dao: "DataAccessObjectABC" = None, + ): + """ + Add a reference to another table for the given attribute + :param Attribute attr: Name of the attribute in the object + :param str primary_attr: Name of the primary key in the foreign object + :param str foreign_attr: Name of the foreign key in the object + :param str table_name: Name of the table to reference + :param DataAccessObjectABC reference_dao: The data access object for the referenced table + :return: + """ + if isinstance(attr, property): + attr = attr.fget.__name__ + + if isinstance(primary_attr, property): + primary_attr = primary_attr.fget.__name__ + + primary_attr = primary_attr.lower().replace("_", "") + + if isinstance(foreign_attr, property): + foreign_attr = foreign_attr.fget.__name__ + + foreign_attr = foreign_attr.lower().replace("_", "") + + self.__foreign_table_keys[attr] = foreign_attr + if reference_dao is not None: + self.__foreign_dao[attr] = reference_dao + + if table_name == self._table_name: + return + + self.__foreign_tables[attr] = ( + table_name, + f"{table_name}.{primary_attr} = {self._table_name}.{foreign_attr}", + ) + + def use_external_fields(self, builder: ExternalDataTempTableBuilder): + self._external_fields[builder.table_name] = builder + + def to_object(self, result: dict) -> T_DBM: + """ + Convert a result from the database to an object + :param dict result: Result from the database + :return: + """ + value_map: dict[str, T] = {} + + for db_name, value in result.items(): + # Find the attribute name corresponding to the db_name + attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None) + if attr_name: + value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value) + + return self._model_type(**value_map) + + def to_dict(self, obj: T_DBM) -> dict: + """ + Convert an object to a dictionary + :param T_DBM obj: Object to convert + :return: + """ + value_map: dict[str, Any] = {} + + for attr_name, attr_type in self.__attributes.items(): + value = getattr(obj, attr_name) + if isinstance(value, datetime.datetime): + value = value.strftime(DATETIME_FORMAT) + elif isinstance(value, Enum): + value = value.value + + value_map[attr_name] = value + + for ex_fname in self._external_fields: + ex_field = self._external_fields[ex_fname] + for ex_attr in ex_field.fields: + if ex_attr == self.__primary_key: + continue + + value_map[ex_attr] = getattr(obj, ex_attr, None) + + return value_map + + async def count(self, filters: AttributeFilters = None) -> int: + result = await self._prepare_query(filters=filters, for_count=True) + return result[0]["count"] if result else 0 + + async def get_history( + self, + entry_id: int, + by_key: str = None, + when: datetime = None, + until: datetime = None, + without_deleted: bool = False, + ) -> list[T_DBM]: + """ + Retrieve the history of an entry from the history table. + :param entry_id: The ID of the entry to retrieve history for. + :param by_key: The key to filter by (default is the primary key). + :param when: A specific timestamp to filter the history. + :param until: A timestamp to filter history entries up to a certain point. + :param without_deleted: Exclude deleted entries if True. + :return: A list of historical entries as objects. + """ + f_tables = list(self.__foreign_tables.keys()) + + history_table = f"{self._table_name}_history" + builder = SQLSelectBuilder(history_table, self.__primary_key) + + builder.with_attribute("*") + builder.with_value_condition( + f"{history_table}.{by_key or self.__primary_key}", + "=", + str(entry_id), + f_tables, + ) + + if self._default_filter_condition: + builder.with_condition(self._default_filter_condition, "", f_tables) + + if without_deleted: + builder.with_value_condition(f"{history_table}.deleted", "=", "false", f_tables) + + if when: + builder.with_value_condition( + self._attr_from_date_to_char(f"{history_table}.updated"), + "=", + f"'{when.strftime(DATETIME_FORMAT)}'", + f_tables, + ) + + if until: + builder.with_value_condition( + self._attr_from_date_to_char(f"{history_table}.updated"), + "<=", + f"'{until.strftime(DATETIME_FORMAT)}'", + f_tables, + ) + + builder.with_order_by(f"{history_table}.updated", "DESC") + + query = await builder.build() + result = await self._db.select_map(query) + return [self.to_object(x) for x in result] if result else [] + + async def get_all(self) -> List[T_DBM]: + result = await self._prepare_query(sorts=[{self.__primary_key: "asc"}]) + return [self.to_object(x) for x in result] if result else [] + + async def get_by_id(self, id: Union[int, str]) -> Optional[T_DBM]: + result = await self._prepare_query(filters=[{self.__primary_key: id}], sorts=[{self.__primary_key: "asc"}]) + return self.to_object(result[0]) if result else None + + async def find_by_id(self, id: Union[int, str]) -> Optional[T_DBM]: + result = await self._prepare_query(filters=[{self.__primary_key: id}], sorts=[{self.__primary_key: "asc"}]) + return self.to_object(result[0]) if result else None + + async def get_by( + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, + ) -> list[T_DBM]: + result = await self._prepare_query(filters, sorts, take, skip) + if not result or len(result) == 0: + raise ValueError("No result found") + return [self.to_object(x) for x in result] if result else [] + + async def get_single_by( + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, + ) -> T_DBM: + result = await self._prepare_query(filters, sorts, take, skip) + if not result: + raise ValueError("No result found") + if len(result) > 1: + raise ValueError("More than one result found") + return self.to_object(result[0]) + + async def find_by( + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, + ) -> list[T_DBM]: + result = await self._prepare_query(filters, sorts, take, skip) + return [self.to_object(x) for x in result] if result else [] + + async def find_single_by( + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, + ) -> Optional[T_DBM]: + result = await self._prepare_query(filters, sorts, take, skip) + if len(result) > 1: + raise ValueError("More than one result found") + return self.to_object(result[0]) if result else None + + async def touch(self, obj: T_DBM): + """ + Touch the entry to update the last updated date + :return: + """ + await self._db.execute( + f""" + UPDATE {self._table_name} + SET updated = NOW() + WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)}; + """ + ) + + async def touch_many_by_id(self, ids: list[Id]): + """ + Touch the entries to update the last updated date + :return: + """ + if len(ids) == 0: + return + + await self._db.execute( + f""" + UPDATE {self._table_name} + SET updated = NOW() + WHERE {self.__primary_key} IN ({", ".join([str(x) for x in ids])}); + """ + ) + + async def _build_create_statement(self, obj: T_DBM, skip_editor=False) -> str: + allowed_fields = [x for x in self.__attributes.keys() if x not in self.__ignored_attributes] + + fields = ", ".join([self.__db_names[x] for x in allowed_fields]) + fields = f"{'EditorId' if not skip_editor else ''}{f', {fields}' if not skip_editor and len(fields) > 0 else f'{fields}'}" + + values = ", ".join([self._get_value_sql(getattr(obj, x)) for x in allowed_fields]) + values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}" + + return f""" + INSERT INTO {self._table_name} ( + {fields} + ) VALUES ( + {values} + ) + RETURNING {self.__primary_key}; + """ + + async def create(self, obj: T_DBM, skip_editor=False) -> int: + self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}") + + result = await self._db.execute(await self._build_create_statement(obj, skip_editor)) + return self._get_value_from_sql(self.__primary_key_type, result[0][0]) + + async def create_many(self, objs: list[T_DBM], skip_editor=False) -> list[int]: + if len(objs) == 0: + return [] + self._logger.debug(f"create many {type(objs[0]).__name__} {len(objs)} {[x.__dict__ for x in objs]}") + + query = "" + for obj in objs: + query += await self._build_create_statement(obj, skip_editor) + + result = await self._db.execute(query) + return [self._get_value_from_sql(self.__primary_key_type, x[0]) for x in result] + + async def _build_update_statement(self, obj: T_DBM, skip_editor=False) -> str: + allowed_fields = [x for x in self.__attributes.keys() if x not in self.__ignored_attributes] + + fields = ", ".join( + [f"{self.__db_names[x]} = {self._get_value_sql(getattr(obj, x, None))}" for x in allowed_fields] + ) + fields = f"{f'EditorId = {await self._get_editor_id(obj)}' if not skip_editor else ''}{f', {fields}' if not skip_editor and len(fields) > 0 else f'{fields}'}" + + return f""" + UPDATE {self._table_name} + SET {fields} + WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)}; + """ + + async def update(self, obj: T_DBM, skip_editor=False): + self._logger.debug(f"update {type(obj).__name__} {obj.__dict__}") + await self._db.execute(await self._build_update_statement(obj, skip_editor)) + + async def update_many(self, objs: list[T_DBM], skip_editor=False): + if len(objs) == 0: + return + self._logger.debug(f"update many {type(objs[0]).__name__} {len(objs)} {[x.__dict__ for x in objs]}") + + query = "" + for obj in objs: + query += await self._build_update_statement(obj, skip_editor) + + await self._db.execute(query) + + async def _build_delete_statement(self, obj: T_DBM, hard_delete: bool = False) -> str: + if hard_delete: + return f""" + DELETE FROM {self._table_name} + WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)}; + """ + + return f""" + UPDATE {self._table_name} + SET EditorId = {await self._get_editor_id(obj)}, + Deleted = true + WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)}; + """ + + async def delete(self, obj: T_DBM, hard_delete: bool = False): + self._logger.debug(f"delete {type(obj).__name__} {obj.__dict__}") + await self._db.execute(await self._build_delete_statement(obj, hard_delete)) + + async def delete_many(self, objs: list[T_DBM], hard_delete: bool = False): + if len(objs) == 0: + return + self._logger.debug(f"delete many {type(objs[0]).__name__} {len(objs)} {[x.__dict__ for x in objs]}") + + query = "" + for obj in objs: + query += await self._build_delete_statement(obj, hard_delete) + + await self._db.execute(query) + + async def _build_restore_statement(self, obj: T_DBM) -> str: + return f""" + UPDATE {self._table_name} + SET EditorId = {await self._get_editor_id(obj)}, + Deleted = false + WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)}; + """ + + async def restore(self, obj: T_DBM): + self._logger.debug(f"restore {type(obj).__name__} {obj.__dict__}") + await self._db.execute(await self._build_restore_statement(obj)) + + async def restore_many(self, objs: list[T_DBM]): + if len(objs) == 0: + return + self._logger.debug(f"restore many {type(objs[0]).__name__} {len(objs)} {objs[0].__dict__}") + + query = "" + for obj in objs: + query += await self._build_restore_statement(obj) + + await self._db.execute(query) + + async def _prepare_query( + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, + for_count=False, + ) -> list[dict]: + """ + Prepares and executes a query using the SQLBuilder with the given parameters. + :param filters: Conditions to filter the query. + :param sorts: Sorting attributes and directions. + :param take: Limit the number of results. + :param skip: Offset the results. + :return: Query result as a list of dictionaries. + """ + external_table_deps = [] + builder = SQLSelectBuilder(self._table_name, self.__primary_key) + + for temp in self._external_fields: + builder.with_temp_table(self._external_fields[temp]) + + if for_count: + builder.with_attribute("COUNT(*)", ignore_table_name=True) + else: + builder.with_attribute("*") + + for attr in self.__foreign_tables: + table, join_condition = self.__foreign_tables[attr] + builder.with_left_join(table, join_condition) + + if filters: + await self._build_conditions(builder, filters, external_table_deps) + + if sorts: + self._build_sorts(builder, sorts, external_table_deps) + + if take: + builder.with_limit(take) + + if skip: + builder.with_offset(skip) + + for external_table in external_table_deps: + builder.use_temp_table(external_table) + + query = await builder.build() + return await self._db.select_map(query) + + async def _build_conditions( + self, + builder: SQLSelectBuilder, + filters: AttributeFilters, + external_table_deps: list[str], + ): + """ + Builds SQL conditions from GraphQL-like filters and adds them to the SQLBuilder. + :param builder: The SQLBuilder instance to add conditions to. + :param filters: GraphQL-like filter structure. + :param external_table_deps: List to store external table dependencies. + """ + if not isinstance(filters, list): + filters = [filters] + + for filter_group in filters: + sql_conditions = self._graphql_to_sql_conditions(filter_group, external_table_deps) + for attr, operator, value in sql_conditions: + if attr in self.__foreign_table_keys: + attr = self.__foreign_table_keys[attr] + + recursive_join = self._get_recursive_reference_join(attr) + if recursive_join is not None: + builder.with_left_join(*recursive_join) + + external_table = self._get_external_field_key(attr) + if external_table is not None: + external_table_deps.append(external_table) + + if operator == "fuzzy": + builder.with_levenshtein_condition(attr) + elif operator in [ + "IS NULL", + "IS NOT NULL", + ]: # operator without value + builder.with_condition( + attr, + operator, + [ + x[0] + for fdao in self.__foreign_dao + for x in self.__foreign_dao[fdao].__foreign_tables.values() + ], + ) + else: + if attr in self.__date_attributes or String.to_snake_case(attr) in self.__date_attributes: + attr = self._attr_from_date_to_char(f"{self._table_name}.{attr}") + + builder.with_value_condition( + attr, + operator, + self._get_value_sql(value), + [ + x[0] + for fdao in self.__foreign_dao + for x in self.__foreign_dao[fdao].__foreign_tables.values() + ], + ) + + def _graphql_to_sql_conditions( + self, graphql_structure: dict, external_table_deps: list[str] + ) -> list[tuple[str, str, Any]]: + """ + Converts a GraphQL-like structure to SQL conditions. + :param graphql_structure: The GraphQL-like filter structure. + :param external_table_deps: List to track external table dependencies. + :return: A list of tuples (attribute, operator, value). + """ + + operators = { + "equal": "=", + "notEqual": "!=", + "greater": ">", + "greaterOrEqual": ">=", + "less": "<", + "lessOrEqual": "<=", + "isNull": "IS NULL", + "isNotNull": "IS NOT NULL", + "contains": "LIKE", # Special handling in _graphql_to_sql_conditions + "notContains": "NOT LIKE", # Special handling in _graphql_to_sql_conditions + "startsWith": "LIKE", # Special handling in _graphql_to_sql_conditions + "endsWith": "LIKE", # Special handling in _graphql_to_sql_conditions + "in": "IN", + "notIn": "NOT IN", + } + conditions = [] + + def parse_node(node, parent_key=None, parent_dao=None): + if not isinstance(node, dict): + return + + if isinstance(node, list): + conditions.append((parent_key, "IN", node)) + return + + for key, value in node.items(): + if isinstance(key, property): + key = key.fget.__name__ + + external_fields_table_name_by_parent = self._get_external_field_key(parent_key) + external_fields_table_name = self._get_external_field_key(key) + external_field = ( + external_fields_table_name + if external_fields_table_name_by_parent is None + else external_fields_table_name_by_parent + ) + + if key == "fuzzy": + self._handle_fuzzy_filter_conditions(conditions, external_table_deps, value) + elif parent_dao is not None and key in parent_dao.__db_names: + parse_node(value, f"{parent_dao.table_name}.{key}") + continue + + elif external_field is not None: + external_table_deps.append(external_field) + parse_node(value, f"{external_field}.{key}") + elif parent_key in self.__foreign_table_keys: + if key in operators: + parse_node({key: value}, self.__foreign_table_keys[parent_key]) + continue + + if parent_key in self.__foreign_dao: + foreign_dao = self.__foreign_dao[parent_key] + if key in foreign_dao.__foreign_tables: + parse_node( + value, + f"{self.__foreign_tables[parent_key][0]}.{foreign_dao.__foreign_table_keys[key]}", + foreign_dao.__foreign_dao[key], + ) + continue + + if parent_key in self.__foreign_tables: + parse_node(value, f"{self.__foreign_tables[parent_key][0]}.{key}") + continue + + parse_node({parent_key: value}) + elif key in operators: + operator = operators[key] + if key == "contains" or key == "notContains": + value = f"%{value}%" + elif key == "in" or key == "notIn": + value = value + elif key == "startsWith": + value = f"{value}%" + elif key == "endsWith": + value = f"%{value}" + elif key == "isNull" or key == "isNotNull": + is_null_value = value.get("equal", None) if isinstance(value, dict) else value + + if is_null_value is None: + operator = operators[key] + elif (key == "isNull" and is_null_value) or (key == "isNotNull" and not is_null_value): + operator = "IS NULL" + else: + operator = "IS NOT NULL" + + conditions.append((parent_key, operator, None)) + elif (key == "equal" or key == "notEqual") and value is None: + operator = operators["isNull"] + + conditions.append((parent_key, operator, value)) + + elif isinstance(value, dict): + if key in self.__foreign_table_keys: + parse_node(value, key) + elif key in self.__db_names and parent_key is not None: + parse_node({f"{parent_key}": value}) + elif key in self.__db_names: + parse_node(value, self.__db_names[key]) + else: + parse_node(value, key) + elif value is None: + conditions.append((self.__db_names[key], "IS NULL", value)) + else: + conditions.append((self.__db_names[key], "=", value)) + + parse_node(graphql_structure) + return conditions + + def _handle_fuzzy_filter_conditions(self, conditions, external_field_table_deps, sub_values): + # Extract fuzzy filter parameters + fuzzy_fields = get_value(sub_values, "fields", list[str]) + fuzzy_term = get_value(sub_values, "term", str) + fuzzy_threshold = get_value(sub_values, "threshold", int, 5) + + if not fuzzy_fields or not fuzzy_term: + raise ValueError("Fuzzy filter must include 'fields' and 'term'.") + + fuzzy_fields_db_names = [] + + # Map fields to their database names + for fuzzy_field in fuzzy_fields: + external_fields_table_name = self._get_external_field_key(fuzzy_field) + if external_fields_table_name is not None: + external_fields_table = self._external_fields[external_fields_table_name] + fuzzy_fields_db_names.append(f"{external_fields_table.table_name}.{fuzzy_field}") + external_field_table_deps.append(external_fields_table.table_name) + elif fuzzy_field in self.__db_names: + fuzzy_fields_db_names.append(f"{self._table_name}.{self.__db_names[fuzzy_field]}") + elif fuzzy_field in self.__foreign_tables: + fuzzy_fields_db_names.append(f"{self._table_name}.{self.__foreign_table_keys[fuzzy_field]}") + else: + fuzzy_fields_db_names.append(self.__db_names[String.to_snake_case(fuzzy_field)][0]) + + # Build fuzzy conditions for each field + fuzzy_conditions = self._build_fuzzy_conditions(fuzzy_fields_db_names, fuzzy_term, fuzzy_threshold) + + # Combine conditions with OR and append to the main conditions + conditions.append((f"({' OR '.join(fuzzy_conditions)})", "fuzzy", None)) + + @staticmethod + def _build_fuzzy_conditions(fields: list[str], term: str, threshold: int = 10) -> list[str]: + conditions = [] + for field in fields: + conditions.append(f"levenshtein({field}::TEXT, '{term}') <= {threshold}") # Adjust the threshold as needed + + return conditions + + def _get_external_field_key(self, field_name: str) -> Optional[str]: + """ + Returns the key to get the external field if found, otherwise None. + :param str field_name: The name of the field to search for. + :return: The key if found, otherwise None. + :rtype: Optional[str] + """ + if field_name is None: + return None + + for key, builder in self._external_fields.items(): + if field_name in builder.fields and field_name not in self.__db_names: + return key + + return None + + def _get_recursive_reference_join(self, attr: str) -> Optional[tuple[str, str]]: + parts = attr.split(".") + table_name = ".".join(parts[:-1]) + + if table_name == self._table_name or table_name == "": + return None + + all_foreign_tables = { + x[0]: x[1] + for x in [ + *[x for x in self.__foreign_tables.values() if x[0] != self._table_name], + *[x for fdao in self.__foreign_dao for x in self.__foreign_dao[fdao].__foreign_tables.values()], + ] + } + + if not table_name in all_foreign_tables: + return None + + return table_name, all_foreign_tables[table_name] + + def _build_sorts( + self, + builder: SQLSelectBuilder, + sorts: AttributeSorts, + external_table_deps: list[str], + ): + """ + Resolves complex sorting structures into SQL-compatible sorting conditions. + Tracks external table dependencies. + :param builder: The SQLBuilder instance to add sorting to. + :param sorts: Sorting attributes and directions in a complex structure. + :param external_table_deps: List to track external table dependencies. + """ + + def parse_sort_node(node, parent_key=None): + if isinstance(node, dict): + for key, value in node.items(): + if isinstance(value, dict): + # Recursively parse nested structures + parse_sort_node(value, key) + elif isinstance(value, str) and value.lower() in ["asc", "desc"]: + external_table = self._get_external_field_key(key) + if external_table: + external_table_deps.append(external_table) + key = f"{external_table}.{key}" + + if parent_key in self.__foreign_tables: + key = f"{self.__foreign_tables[parent_key][0]}.{key}" + builder.with_order_by(key, value.upper()) + else: + raise ValueError(f"Invalid sort direction: {value}") + elif isinstance(node, list): + for item in node: + parse_sort_node(item) + else: + raise ValueError(f"Invalid sort structure: {node}") + + parse_sort_node(sorts) + + def _get_value_sql(self, value: Any) -> str: + if isinstance(value, str): + if value.lower() == "null": + return "NULL" + return f"'{value}'" + + if isinstance(value, NoneType): + return "NULL" + + if value is None: + return "NULL" + + if isinstance(value, Enum): + return f"'{value.value}'" + + if isinstance(value, bool): + return "true" if value else "false" + + if isinstance(value, list): + if len(value) == 0: + return "()" + return f"({', '.join([self._get_value_sql(x) for x in value])})" + + if isinstance(value, datetime.datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=datetime.timezone.utc) + + return f"'{value.strftime(DATETIME_FORMAT)}'" + + return str(value) + + @staticmethod + def _get_value_from_sql(cast_type: type, value: Any) -> Optional[T]: + """ + Get the value from the query result and cast it to the correct type + :param type cast_type: + :param Any value: + :return Optional[T]: Casted value, when value is str "NULL" None is returned + """ + if isinstance(value, str) and "NULL" in value: + return None + + if isinstance(value, NoneType): + return None + + if isinstance(value, cast_type): + return value + + return cast_type(value) + + def _get_primary_key_value_sql(self, obj: T_DBM) -> str: + value = getattr(obj, self.__primary_key) + if isinstance(value, str): + return f"'{value}'" + + return value + + @staticmethod + def _attr_from_date_to_char(attr: str) -> str: + return f"TO_CHAR({attr}, 'YYYY-MM-DD HH24:MI:SS.US TZ')" + + @staticmethod + async def _get_editor_id(obj: T_DBM): + editor_id = obj.editor_id + # if editor_id is None: + # user = get_user() + # if user is not None: + # editor_id = user.id + + return editor_id if editor_id is not None else "NULL" diff --git a/src/cpl-database/cpl/database/abc/db_context_abc.py b/src/cpl-database/cpl/database/abc/db_context_abc.py new file mode 100644 index 00000000..cd0cc7be --- /dev/null +++ b/src/cpl-database/cpl/database/abc/db_context_abc.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Any + +from cpl.database.model.database_settings import DatabaseSettings + + +class DBContextABC(ABC): + r"""ABC for the :class:`cpl.database.context.database_context.DatabaseContext`""" + + @abstractmethod + def connect(self, database_settings: DatabaseSettings): + r"""Connects to a database by connection settings + + Parameter: + database_settings :class:`cpl.database.database_settings.DatabaseSettings` + """ + + @abstractmethod + async def execute(self, statement: str, args=None, multi=True) -> list[list]: + r"""Runs SQL Statements + + Parameter: + statement: :class:`str` + args: :class:`list` | :class:`tuple` | :class:`dict` | :class:`None` + multi: :class:`bool` + + Returns: + list: Fetched list of executed elements + """ + + @abstractmethod + async def select_map(self, statement: str, args=None) -> list[dict]: + r"""Runs SQL Select Statements and returns a list of dictionaries + + Parameter: + statement: :class:`str` + args: :class:`list` | :class:`tuple` | :class:`dict` | :class:`None` + + Returns: + list: Fetched list of executed elements as dictionary + """ + + @abstractmethod + async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]: + r"""Runs SQL Select Statements and returns a list of dictionaries + + Parameter: + statement: :class:`str` + args: :class:`list` | :class:`tuple` | :class:`dict` | :class:`None` + + Returns: + list: Fetched list of executed elements + """ diff --git a/src/cpl-database/cpl/database/abc/db_join_model_abc.py b/src/cpl-database/cpl/database/abc/db_join_model_abc.py new file mode 100644 index 00000000..c81bd50d --- /dev/null +++ b/src/cpl-database/cpl/database/abc/db_join_model_abc.py @@ -0,0 +1,30 @@ +from datetime import datetime +from typing import Optional + +from cpl.core.typing import Id, SerialId +from cpl.database.abc.db_model_abc import DbModelABC + + +class DbJoinModelABC[T](DbModelABC[T]): + def __init__( + self, + id: Id, + source_id: Id, + foreign_id: Id, + deleted: bool = False, + editor_id: Optional[SerialId] = None, + created: Optional[datetime] = None, + updated: Optional[datetime] = None, + ): + DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + + self._source_id = source_id + self._foreign_id = foreign_id + + @property + def source_id(self) -> Id: + return self._source_id + + @property + def foreign_id(self) -> Id: + return self._foreign_id diff --git a/src/cpl-database/cpl/database/abc/db_model_abc.py b/src/cpl-database/cpl/database/abc/db_model_abc.py new file mode 100644 index 00000000..edbd1f3b --- /dev/null +++ b/src/cpl-database/cpl/database/abc/db_model_abc.py @@ -0,0 +1,79 @@ +from abc import ABC +from datetime import datetime, timezone +from typing import Optional, Generic + +from cpl.core.typing import Id, SerialId, T + + +class DbModelABC(ABC, Generic[T]): + def __init__( + self, + id: Id, + deleted: bool = False, + editor_id: Optional[SerialId] = None, + created: Optional[datetime] = None, + updated: Optional[datetime] = None, + ): + self._id = id + self._deleted = deleted + self._editor_id = editor_id + + self._created = created if created is not None else datetime.now(timezone.utc).isoformat() + self._updated = updated if updated is not None else datetime.now(timezone.utc).isoformat() + + @property + def id(self) -> Id: + return self._id + + @property + def deleted(self) -> bool: + return self._deleted + + @deleted.setter + def deleted(self, value: bool): + self._deleted = value + + @property + def editor_id(self) -> SerialId: + return self._editor_id + + @editor_id.setter + def editor_id(self, value: SerialId): + self._editor_id = value + + # @async_property + # async def editor(self): + # if self._editor_id is None: + # return None + # + # from data.schemas.administration.user_dao import userDao + # + # return await userDao.get_by_id(self._editor_id) + + @property + def created(self) -> datetime: + return self._created + + @property + def updated(self) -> datetime: + return self._updated + + @updated.setter + def updated(self, value: datetime): + self._updated = value + + def to_dict(self) -> dict: + result = {} + for name, value in self.__dict__.items(): + if not name.startswith("_") or name.endswith("_"): + continue + + if isinstance(value, datetime): + value = value.isoformat() + + if not isinstance(value, str): + value = str(value) + + result[name.replace("_", "")] = value + + return result diff --git a/src/cpl-database/cpl/database/abc/db_model_dao_abc.py b/src/cpl-database/cpl/database/abc/db_model_dao_abc.py new file mode 100644 index 00000000..760ece86 --- /dev/null +++ b/src/cpl-database/cpl/database/abc/db_model_dao_abc.py @@ -0,0 +1,25 @@ +from abc import abstractmethod +from datetime import datetime +from typing import Type + +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.database.abc.db_model_abc import DbModelABC +from cpl.database.internal_tables import InternalTables + + +class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]): + + @abstractmethod + def __init__(self, source: str, model_type: Type[T_DBM], table_name: str): + DataAccessObjectABC.__init__(self, source, model_type, table_name) + + self.attribute(DbModelABC.id, int, ignore=True) + self.attribute(DbModelABC.deleted, bool) + self.attribute(DbModelABC.editor_id, int, ignore=True) # handled by db trigger + + self.reference( + "editor", "id", DbModelABC.editor_id, InternalTables.users + ) # not relevant for updates due to editor_id + + self.attribute(DbModelABC.created, datetime, ignore=True) # handled by db trigger + self.attribute(DbModelABC.updated, datetime, ignore=True) # handled by db trigger diff --git a/src/cpl-database/cpl/database/table_abc.py b/src/cpl-database/cpl/database/abc/table_abc.py similarity index 100% rename from src/cpl-database/cpl/database/table_abc.py rename to src/cpl-database/cpl/database/abc/table_abc.py diff --git a/src/cpl-database/cpl/database/const.py b/src/cpl-database/cpl/database/const.py new file mode 100644 index 00000000..355c43a5 --- /dev/null +++ b/src/cpl-database/cpl/database/const.py @@ -0,0 +1 @@ +DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z" diff --git a/src/cpl-database/cpl/database/database_settings_name_enum.py b/src/cpl-database/cpl/database/database_settings_name_enum.py deleted file mode 100644 index 56b59a3f..00000000 --- a/src/cpl-database/cpl/database/database_settings_name_enum.py +++ /dev/null @@ -1,13 +0,0 @@ -from enum import Enum - - -class DatabaseSettingsNameEnum(Enum): - host = "Host" - port = "Port" - user = "User" - password = "Password" - database = "Database" - charset = "Charset" - use_unicode = "UseUnicode" - buffered = "Buffered" - auth_plugin = "AuthPlugin" diff --git a/src/cpl-database/cpl/database/internal_tables.py b/src/cpl-database/cpl/database/internal_tables.py new file mode 100644 index 00000000..07d7e667 --- /dev/null +++ b/src/cpl-database/cpl/database/internal_tables.py @@ -0,0 +1,15 @@ +from cpl.database.model.server_type import ServerTypes, ServerType + + + +class InternalTables: + + @classmethod + @property + def users(cls) -> str: + return "administration.users" if ServerType.server_type is ServerTypes.POSTGRES else "users" + + @classmethod + @property + def executed_migrations(cls) -> str: + return "system._executed_migrations" if ServerType.server_type is ServerTypes.POSTGRES else "_executed_migrations" diff --git a/src/cpl-database/cpl/database/model/__init__.py b/src/cpl-database/cpl/database/model/__init__.py new file mode 100644 index 00000000..4c3c0b10 --- /dev/null +++ b/src/cpl-database/cpl/database/model/__init__.py @@ -0,0 +1,3 @@ +from .database_settings import DatabaseSettings +from .migration import Migration +from .server_type import ServerTypes diff --git a/src/cpl-database/cpl/database/database_settings.py b/src/cpl-database/cpl/database/model/database_settings.py similarity index 62% rename from src/cpl-database/cpl/database/database_settings.py rename to src/cpl-database/cpl/database/model/database_settings.py index cdae643b..260fd2a4 100644 --- a/src/cpl-database/cpl/database/database_settings.py +++ b/src/cpl-database/cpl/database/model/database_settings.py @@ -1,6 +1,9 @@ from typing import Optional +from cpl.core.configuration import Configuration from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC +from cpl.core.environment import Environment +from cpl.core.utils import Base64 class DatabaseSettings(ConfigurationModelABC): @@ -8,23 +11,23 @@ class DatabaseSettings(ConfigurationModelABC): def __init__( self, - host: str = None, - port: int = 3306, - user: str = None, - password: str = None, - database: str = None, - charset: str = "utf8mb4", - use_unicode: bool = False, - buffered: bool = False, - auth_plugin: str = "caching_sha2_password", - ssl_disabled: bool = False, + host: str = Environment.get("DB_HOST", str), + port: int = Environment.get("DB_PORT", str, Configuration.get("DB_DEFAULT_PORT", 0)), + user: str = Environment.get("DB_USER", str), + password: str = Environment.get("DB_PASSWORD", str), + database: str = Environment.get("DB_DATABASE", str), + charset: str = Environment.get("DB_CHARSET", str, "utf8mb4"), + use_unicode: bool = Environment.get("DB_USE_UNICODE", bool, False), + buffered: bool = Environment.get("DB_BUFFERED", bool, False), + auth_plugin: str = Environment.get("DB_AUTH_PLUGIN", str, "caching_sha2_password"), + ssl_disabled: bool = Environment.get("DB_SSL_DISABLED", bool, False), ): ConfigurationModelABC.__init__(self) self._host: Optional[str] = host self._port: Optional[int] = port self._user: Optional[str] = user - self._password: Optional[str] = password + self._password: Optional[str] = Base64.decode(password) if Base64.is_b64(password) else password self._database: Optional[str] = database self._charset: Optional[str] = charset self._use_unicode: Optional[bool] = use_unicode diff --git a/src/cpl-database/cpl/database/model/migration.py b/src/cpl-database/cpl/database/model/migration.py new file mode 100644 index 00000000..a32cc824 --- /dev/null +++ b/src/cpl-database/cpl/database/model/migration.py @@ -0,0 +1,12 @@ +class Migration: + def __init__(self, name: str, script: str): + self._name = name + self._script = script + + @property + def name(self) -> str: + return self._name + + @property + def script(self) -> str: + return self._script diff --git a/src/cpl-database/cpl/database/model/server_type.py b/src/cpl-database/cpl/database/model/server_type.py new file mode 100644 index 00000000..dbdd40e0 --- /dev/null +++ b/src/cpl-database/cpl/database/model/server_type.py @@ -0,0 +1,21 @@ +from enum import Enum + + +class ServerTypes(Enum): + POSTGRES = "postgres" + MYSQL = "mysql" + +class ServerType: + _server_type: ServerTypes = None + + @classmethod + def set_server_type(cls, server_type: ServerTypes): + assert server_type is not None, "server_type must not be None" + assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}" + cls._server_type = server_type + + @classmethod + @property + def server_type(cls) -> ServerTypes: + assert cls._server_type is not None, "Server type is not set" + return cls._server_type \ No newline at end of file diff --git a/src/cpl-database/cpl/database/mysql/connection/database_connection.py b/src/cpl-database/cpl/database/mysql/connection.py similarity index 90% rename from src/cpl-database/cpl/database/mysql/connection/database_connection.py rename to src/cpl-database/cpl/database/mysql/connection.py index 59753b6a..baa16e0f 100644 --- a/src/cpl-database/cpl/database/mysql/connection/database_connection.py +++ b/src/cpl-database/cpl/database/mysql/connection.py @@ -4,16 +4,16 @@ import mysql.connector as sql from mysql.connector.abstracts import MySQLConnectionAbstract from mysql.connector.cursor import MySQLCursorBuffered -from cpl.database.mysql.connection.database_connection_abc import DatabaseConnectionABC +from cpl.database.abc.connection_abc import ConnectionABC from cpl.database.database_settings import DatabaseSettings from cpl.core.utils.credential_manager import CredentialManager -class DatabaseConnection(DatabaseConnectionABC): +class DatabaseConnection(ConnectionABC): r"""Representation of the database connection""" def __init__(self): - DatabaseConnectionABC.__init__(self) + ConnectionABC.__init__(self) self._database: Optional[MySQLConnectionAbstract] = None self._cursor: Optional[MySQLCursorBuffered] = None diff --git a/src/cpl-database/cpl/database/mysql/connection/__init__.py b/src/cpl-database/cpl/database/mysql/connection/__init__.py deleted file mode 100644 index 06103864..00000000 --- a/src/cpl-database/cpl/database/mysql/connection/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .database_connection import DatabaseConnection -from .database_connection_abc import DatabaseConnectionABC diff --git a/src/cpl-database/cpl/database/mysql/context/__init__.py b/src/cpl-database/cpl/database/mysql/context/__init__.py deleted file mode 100644 index b061887d..00000000 --- a/src/cpl-database/cpl/database/mysql/context/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .database_context import DatabaseContext -from .database_context_abc import DatabaseContextABC diff --git a/src/cpl-database/cpl/database/mysql/context/database_context.py b/src/cpl-database/cpl/database/mysql/context/database_context.py deleted file mode 100644 index afe33970..00000000 --- a/src/cpl-database/cpl/database/mysql/context/database_context.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Optional - - -from cpl.database.mysql.connection.database_connection import DatabaseConnection -from cpl.database.mysql.connection.database_connection_abc import DatabaseConnectionABC -from cpl.database.mysql.context.database_context_abc import DatabaseContextABC -from cpl.database.database_settings import DatabaseSettings -from mysql.connector.cursor import MySQLCursorBuffered - - -class DatabaseContext(DatabaseContextABC): - r"""Representation of the database context - - Parameter: - database_settings: :class:`cpl.database.database_settings.DatabaseSettings` - """ - - def __init__(self): - DatabaseContextABC.__init__(self) - - self._db: DatabaseConnectionABC = DatabaseConnection() - self._settings: Optional[DatabaseSettings] = None - - @property - def cursor(self) -> MySQLCursorBuffered: - self._ping_and_reconnect() - return self._db.cursor - - def _ping_and_reconnect(self): - try: - self._db.server.ping(reconnect=True, attempts=3, delay=5) - except Exception: - # reconnect your cursor as you did in __init__ or wherever - if self._settings is None: - raise Exception("Call DatabaseContext.connect first") - self.connect(self._settings) - - def connect(self, database_settings: DatabaseSettings): - if self._settings is None: - self._settings = database_settings - self._db.connect(database_settings) - - self.save_changes() - - def save_changes(self): - self._ping_and_reconnect() - self._db.server.commit() - - def select(self, statement: str) -> list[tuple]: - self._ping_and_reconnect() - self._db.cursor.execute(statement) - return self._db.cursor.fetchall() diff --git a/src/cpl-database/cpl/database/mysql/context/database_context_abc.py b/src/cpl-database/cpl/database/mysql/context/database_context_abc.py deleted file mode 100644 index 481be1f0..00000000 --- a/src/cpl-database/cpl/database/mysql/context/database_context_abc.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import ABC, abstractmethod - -from cpl.database.database_settings import DatabaseSettings -from mysql.connector.cursor import MySQLCursorBuffered - - -class DatabaseContextABC(ABC): - r"""ABC for the :class:`cpl.database.context.database_context.DatabaseContext`""" - - @abstractmethod - def __init__(self, *args): - pass - - @property - @abstractmethod - def cursor(self) -> MySQLCursorBuffered: - pass - - @abstractmethod - def connect(self, database_settings: DatabaseSettings): - r"""Connects to a database by connection settings - - Parameter: - database_settings :class:`cpl.database.database_settings.DatabaseSettings` - """ - - @abstractmethod - def save_changes(self): - r"""Saves changes of the database""" - - @abstractmethod - def select(self, statement: str) -> list[tuple]: - r"""Runs SQL Statements - - Parameter: - statement: :class:`str` - - Returns: - list: Fetched list of selected elements - """ diff --git a/src/cpl-database/cpl/database/mysql/db_context.py b/src/cpl-database/cpl/database/mysql/db_context.py new file mode 100644 index 00000000..686cc4f8 --- /dev/null +++ b/src/cpl-database/cpl/database/mysql/db_context.py @@ -0,0 +1,84 @@ +import uuid +from typing import Any, List, Dict, Tuple, Union + +from mysql.connector import Error as MySQLError, PoolError + +from cpl.core.configuration import Configuration +from cpl.core.environment import Environment +from cpl.database.abc.db_context_abc import DBContextABC +from cpl.database.db_logger import DBLogger +from cpl.database.model.database_settings import DatabaseSettings +from cpl.database.mysql.mysql_pool import MySQLPool + +_logger = DBLogger(__name__) + + +class DBContext(DBContextABC): + def __init__(self): + DBContextABC.__init__(self) + self._pool: MySQLPool = None + self._fails = 0 + + self.connect(Configuration.get(DatabaseSettings)) + + def connect(self, database_settings: DatabaseSettings): + try: + _logger.debug("Connecting to database") + self._pool = MySQLPool( + database_settings, + ) + _logger.info("Connected to database") + except Exception as e: + _logger.fatal("Connecting to database failed", e) + + async def execute(self, statement: str, args=None, multi=True) -> List[List]: + _logger.trace(f"execute {statement} with args: {args}") + return await self._pool.execute(statement, args, multi) + + async def select_map(self, statement: str, args=None) -> List[Dict]: + _logger.trace(f"select {statement} with args: {args}") + try: + return await self._pool.select_map(statement, args) + except (MySQLError, PoolError) as e: + if self._fails >= 3: + _logger.error(f"Database error caused by `{statement}`", e) + uid = uuid.uuid4() + raise Exception( + f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" + ) + + _logger.error(f"Database error caused by `{statement}`", e) + self._fails += 1 + try: + _logger.debug("Retry select") + return await self.select_map(statement, args) + except Exception as e: + pass + return [] + except Exception as e: + _logger.error(f"Database error caused by `{statement}`", e) + raise e + + async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]: + _logger.trace(f"select {statement} with args: {args}") + try: + return await self._pool.select(statement, args) + except (MySQLError, PoolError) as e: + if self._fails >= 3: + _logger.error(f"Database error caused by `{statement}`", e) + uid = uuid.uuid4() + raise Exception( + f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" + ) + + _logger.error(f"Database error caused by `{statement}`", e) + self._fails += 1 + try: + _logger.debug("Retry select") + return await self.select(statement, args) + except Exception as e: + pass + return [] + except Exception as e: + _logger.error(f"Database error caused by `{statement}`", e) + raise e diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py new file mode 100644 index 00000000..9faed3ce --- /dev/null +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -0,0 +1,105 @@ +from typing import Optional, Any + +import sqlparse +import aiomysql + +from cpl.core.environment import Environment +from cpl.database.db_logger import DBLogger +from cpl.database.model import DatabaseSettings + +_logger = DBLogger(__name__) + + +class MySQLPool: + """ + Create a pool when connecting to MySQL, which will decrease the time spent in + requesting connection, creating connection, and closing connection. + """ + + def __init__(self, database_settings: DatabaseSettings): + self._db_settings = database_settings + self.pool: Optional[aiomysql.Pool] = None + + async def _get_pool(self): + if self.pool is None or self.pool._closed: + try: + self.pool = await aiomysql.create_pool( + host=self._db_settings.host, + port=self._db_settings.port, + user=self._db_settings.user, + password=self._db_settings.password, + db=self._db_settings.database, + minsize=1, + maxsize=Environment.get("DB_POOL_SIZE", int, 1), + autocommit=True, + ) + except Exception as e: + _logger.fatal("Failed to connect to the database", e) + raise + return self.pool + + @staticmethod + async def _exec_sql(cursor: Any, query: str, args=None, multi=True): + if multi: + queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] + for q in queries: + if q.strip() == "": + continue + await cursor.execute(q, args) + else: + await cursor.execute(query, args) + + async def execute(self, query: str, args=None, multi=True) -> list[list]: + """ + Execute a SQL statement, it could be with args and without args. The usage is + similar to the execute() function in aiomysql. + :param query: SQL clause + :param args: args needed by the SQL clause + :param multi: if the query is a multi-statement + :return: return result + """ + pool = await self._get_pool() + async with pool.acquire() as con: + async with con.cursor() as cursor: + await self._exec_sql(cursor, query, args, multi) + + if cursor.description is not None: # Query returns rows + res = await cursor.fetchall() + if res is None: + return [] + + return [list(row) for row in res] + else: + return [] + + async def select(self, query: str, args=None, multi=True) -> list[str]: + """ + Execute a SQL statement, it could be with args and without args. The usage is + similar to the execute() function in aiomysql. + :param query: SQL clause + :param args: args needed by the SQL clause + :param multi: if the query is a multi-statement + :return: return result + """ + pool = await self._get_pool() + async with pool.acquire() as con: + async with con.cursor() as cursor: + await self._exec_sql(cursor, query, args, multi) + res = await cursor.fetchall() + return list(res) + + async def select_map(self, query: str, args=None, multi=True) -> list[dict]: + """ + Execute a SQL statement, it could be with args and without args. The usage is + similar to the execute() function in aiomysql. + :param query: SQL clause + :param args: args needed by the SQL clause + :param multi: if the query is a multi-statement + :return: return result + """ + pool = await self._get_pool() + async with pool.acquire() as con: + async with con.cursor(aiomysql.DictCursor) as cursor: + await self._exec_sql(cursor, query, args, multi) + res = await cursor.fetchall() + return list(res) diff --git a/src/cpl-database/cpl/database/postgres/__init__.py b/src/cpl-database/cpl/database/postgres/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-database/cpl/database/postgres/db_context.py b/src/cpl-database/cpl/database/postgres/db_context.py new file mode 100644 index 00000000..40fcd7f1 --- /dev/null +++ b/src/cpl-database/cpl/database/postgres/db_context.py @@ -0,0 +1,86 @@ +import uuid +from typing import Any + +from psycopg import OperationalError +from psycopg_pool import PoolTimeout + +from cpl.core.configuration import Configuration +from cpl.core.environment import Environment +from cpl.database.abc.db_context_abc import DBContextABC +from cpl.database.database_settings import DatabaseSettings +from cpl.database.db_logger import DBLogger +from cpl.database.postgres.postgres_pool import PostgresPool + +_logger = DBLogger(__name__) + + +class DBContext(DBContextABC): + def __init__(self): + DBContextABC.__init__(self) + self._pool: PostgresPool = None + self._fails = 0 + + self.connect(Configuration.get(DatabaseSettings)) + + def connect(self, database_settings: DatabaseSettings): + try: + _logger.debug("Connecting to database") + self._pool = PostgresPool( + database_settings, + Environment.get("DB_POOL_SIZE", int, 1), + ) + _logger.info("Connected to database") + except Exception as e: + _logger.fatal("Connecting to database failed", e) + + async def execute(self, statement: str, args=None, multi=True) -> list[list]: + _logger.trace(f"execute {statement} with args: {args}") + return await self._pool.execute(statement, args, multi) + + async def select_map(self, statement: str, args=None) -> list[dict]: + _logger.trace(f"select {statement} with args: {args}") + try: + return await self._pool.select_map(statement, args) + except (OperationalError, PoolTimeout) as e: + if self._fails >= 3: + _logger.error(f"Database error caused by `{statement}`", e) + uid = uuid.uuid4() + raise Exception( + f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" + ) + + _logger.error(f"Database error caused by `{statement}`", e) + self._fails += 1 + try: + _logger.debug("Retry select") + return await self.select_map(statement, args) + except Exception as e: + pass + return [] + except Exception as e: + _logger.error(f"Database error caused by `{statement}`", e) + raise e + + async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]: + _logger.trace(f"select {statement} with args: {args}") + try: + return await self._pool.select(statement, args) + except (OperationalError, PoolTimeout) as e: + if self._fails >= 3: + _logger.error(f"Database error caused by `{statement}`", e) + uid = uuid.uuid4() + raise Exception( + f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" + ) + + _logger.error(f"Database error caused by `{statement}`", e) + self._fails += 1 + try: + _logger.debug("Retry select") + return await self.select(statement, args) + except Exception as e: + pass + return [] + except Exception as e: + _logger.error(f"Database error caused by `{statement}`", e) + raise e diff --git a/src/cpl-database/cpl/database/postgres/postgres_pool.py b/src/cpl-database/cpl/database/postgres/postgres_pool.py new file mode 100644 index 00000000..8d74e35b --- /dev/null +++ b/src/cpl-database/cpl/database/postgres/postgres_pool.py @@ -0,0 +1,123 @@ +from typing import Optional, Any + +import sqlparse +from psycopg import sql +from psycopg_pool import AsyncConnectionPool, PoolTimeout + +from cpl.core.environment import Environment +from cpl.database.db_logger import DBLogger +from cpl.database.model import DatabaseSettings + +_logger = DBLogger(__name__) + + +class PostgresPool: + """ + Create a pool when connecting to PostgreSQL, which will decrease the time spent in + requesting connection, creating connection, and closing connection. + """ + + def __init__(self, database_settings: DatabaseSettings): + self._conninfo = ( + f"host={database_settings.host} " + f"port={database_settings.port} " + f"user={database_settings.user} " + f"password={database_settings.password} " + f"dbname={database_settings.database}" + ) + + self.pool: Optional[AsyncConnectionPool] = None + + async def _get_pool(self): + pool = AsyncConnectionPool( + conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) + ) + await pool.open() + try: + async with pool.connection() as con: + await pool.check_connection(con) + except PoolTimeout as e: + await pool.close() + _logger.fatal(f"Failed to connect to the database", e) + return pool + + @staticmethod + async def _exec_sql(cursor: Any, query: str, args=None, multi=True): + if multi: + queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] + for q in queries: + if q.strip() == "": + continue + + await cursor.execute(sql.SQL(q), args) + else: + await cursor.execute(sql.SQL(query), args) + + async def execute(self, query: str, args=None, multi=True) -> list[list]: + """ + Execute a SQL statement, it could be with args and without args. The usage is + similar to the execute() function in the psycopg module. + :param query: SQL clause + :param args: args needed by the SQL clause + :param multi: if the query is a multi-statement + :return: return result + """ + async with await self._get_pool() as pool: + async with pool.connection() as con: + async with con.cursor() as cursor: + await self._exec_sql(cursor, query, args, multi) + + if cursor.description is not None: # Check if the query returns rows + res = await cursor.fetchall() + if res is None: + return [] + + result = [] + for row in res: + result.append(list(row)) + return result + else: + return [] + + async def select(self, query: str, args=None, multi=True) -> list[str]: + """ + Execute a SQL statement, it could be with args and without args. The usage is + similar to the execute() function in the psycopg module. + :param query: SQL clause + :param args: args needed by the SQL clause + :param multi: if the query is a multi-statement + :return: return result + """ + async with await self._get_pool() as pool: + async with pool.connection() as con: + async with con.cursor() as cursor: + await self._exec_sql(cursor, query, args, multi) + + res = await cursor.fetchall() + return list(res) + + async def select_map(self, query: str, args=None, multi=True) -> list[dict]: + """ + Execute a SQL statement, it could be with args and without args. The usage is + similar to the execute() function in the psycopg module. + :param query: SQL clause + :param args: args needed by the SQL clause + :param multi: if the query is a multi-statement + :return: return result + """ + async with await self._get_pool() as pool: + async with pool.connection() as con: + async with con.cursor() as cursor: + await self._exec_sql(cursor, query, args, multi) + + res = await cursor.fetchall() + res_map: list[dict] = [] + + for i_res in range(len(res)): + cols = {} + for i_col in range(len(res[i_res])): + cols[cursor.description[i_col].name] = res[i_res][i_col] + + res_map.append(cols) + + return res_map diff --git a/src/cpl-database/cpl/database/postgres/sql_select_builder.py b/src/cpl-database/cpl/database/postgres/sql_select_builder.py new file mode 100644 index 00000000..08487628 --- /dev/null +++ b/src/cpl-database/cpl/database/postgres/sql_select_builder.py @@ -0,0 +1,154 @@ +from typing import Optional, Union + +from cpl.database._external_data_temp_table_builder import ExternalDataTempTableBuilder + + +class SQLSelectBuilder: + + def __init__(self, table_name: str, primary_key: str): + self._table_name = table_name + self._primary_key = primary_key + + self._temp_tables: dict[str, ExternalDataTempTableBuilder] = {} + self._to_use_temp_tables: list[str] = [] + self._attributes: list[str] = [] + self._tables: list[str] = [table_name] + self._joins: dict[str, (str, str)] = {} + self._conditions: list[str] = [] + self._order_by: str = "" + self._limit: Optional[int] = None + self._offset: Optional[int] = None + + def with_temp_table(self, temp_table: ExternalDataTempTableBuilder) -> "SQLSelectBuilder": + self._temp_tables[temp_table.table_name] = temp_table + return self + + def use_temp_table(self, temp_table_name: str): + if temp_table_name not in self._temp_tables: + raise ValueError(f"Temp table {temp_table_name} not found.") + + self._to_use_temp_tables.append(temp_table_name) + + def with_attribute(self, attr: str, ignore_table_name=False) -> "SQLSelectBuilder": + if not ignore_table_name and not attr.startswith(self._table_name): + attr = f"{self._table_name}.{attr}" + + self._attributes.append(attr) + return self + + def with_foreign_attribute(self, attr: str) -> "SQLSelectBuilder": + self._attributes.append(attr) + return self + + def with_table(self, table_name: str) -> "SQLSelectBuilder": + self._tables.append(table_name) + return self + + def _check_prefix(self, attr: str, foreign_tables: list[str]) -> str: + assert attr is not None + + if "TO_CHAR" in attr: + return attr + + valid_prefixes = [ + "levenshtein", + self._table_name, + *self._joins.keys(), + *self._temp_tables.keys(), + *foreign_tables, + ] + if not any(attr.startswith(f"{prefix}.") for prefix in valid_prefixes): + attr = f"{self._table_name}.{attr}" + + return attr + + def with_value_condition( + self, attr: str, operator: str, value: str, foreign_tables: list[str] + ) -> "SQLSelectBuilder": + attr = self._check_prefix(attr, foreign_tables) + self._conditions.append(f"{attr} {operator} {value}") + return self + + def with_levenshtein_condition(self, condition: str) -> "SQLSelectBuilder": + self._conditions.append(condition) + return self + + def with_condition(self, attr: str, operator: str, foreign_tables: list[str]) -> "SQLSelectBuilder": + attr = self._check_prefix(attr, foreign_tables) + self._conditions.append(f"{attr} {operator}") + return self + + def with_grouped_conditions(self, conditions: list[str]) -> "SQLSelectBuilder": + self._conditions.append(f"({' AND '.join(conditions)})") + return self + + def with_left_join(self, table: str, on: str) -> "SQLSelectBuilder": + if table in self._joins: + self._joins[table] = (f"{self._joins[table][0]} AND {on}", "LEFT") + + self._joins[table] = (on, "LEFT") + return self + + def with_inner_join(self, table: str, on: str) -> "SQLSelectBuilder": + if table in self._joins: + self._joins[table] = (f"{self._joins[table][0]} AND {on}", "INNER") + + self._joins[table] = (on, "INNER") + return self + + def with_right_join(self, table: str, on: str) -> "SQLSelectBuilder": + if table in self._joins: + self._joins[table] = (f"{self._joins[table][0]} AND {on}", "RIGHT") + + self._joins[table] = (on, "RIGHT") + return self + + def with_limit(self, limit: int) -> "SQLSelectBuilder": + self._limit = limit + return self + + def with_offset(self, offset: int) -> "SQLSelectBuilder": + self._offset = offset + return self + + def with_order_by(self, column: Union[str, property], direction: str = "ASC") -> "SQLSelectBuilder": + if isinstance(column, property): + column = column.fget.__name__ + self._order_by = f"{column} {direction}" + return self + + async def _handle_temp_table_use(self, query) -> str: + new_query = "" + + for temp_table_name in self._to_use_temp_tables: + temp_table = self._temp_tables[temp_table_name] + new_query += await self._temp_tables[temp_table_name].build() + self.with_left_join( + temp_table.table_name, + f"{temp_table.join_ref_table}.{self._primary_key} = {temp_table.table_name}.{temp_table.primary_key}", + ) + + return f"{new_query} {query}" if new_query != "" else query + + async def build(self) -> str: + query = await self._handle_temp_table_use("") + + attributes = ", ".join(self._attributes) if self._attributes else "*" + query += f"SELECT {attributes} FROM {", ".join(self._tables)}" + + for join in self._joins: + query += f" {self._joins[join][1]} JOIN {join} ON {self._joins[join][0]}" + + if self._conditions: + query += " WHERE " + " AND ".join(self._conditions) + + if self._order_by: + query += f" ORDER BY {self._order_by}" + + if self._limit is not None: + query += f" LIMIT {self._limit}" + + if self._offset is not None: + query += f" OFFSET {self._offset}" + + return query diff --git a/src/cpl-database/cpl/database/schema/__init__.py b/src/cpl-database/cpl/database/schema/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-database/cpl/database/schema/executed_migration.py b/src/cpl-database/cpl/database/schema/executed_migration.py new file mode 100644 index 00000000..3b9ed1c5 --- /dev/null +++ b/src/cpl-database/cpl/database/schema/executed_migration.py @@ -0,0 +1,18 @@ +from datetime import datetime +from typing import Optional + +from cpl.database.abc import DbModelABC + + +class ExecutedMigration(DbModelABC): + def __init__( + self, + migration_id: str, + created: Optional[datetime] = None, + modified: Optional[datetime] = None, + ): + DbModelABC.__init__(self, migration_id, False, created, modified) + + @property + def migration_id(self) -> str: + return self._id diff --git a/src/cpl-database/cpl/database/schema/executed_migration_dao.py b/src/cpl-database/cpl/database/schema/executed_migration_dao.py new file mode 100644 index 00000000..cef92ce3 --- /dev/null +++ b/src/cpl-database/cpl/database/schema/executed_migration_dao.py @@ -0,0 +1,14 @@ +from cpl.database import InternalTables +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.database.db_logger import DBLogger +from cpl.database.schema.executed_migration import ExecutedMigration + +_logger = DBLogger(__name__) + + +class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]): + + def __init__(self): + DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, InternalTables.executed_migrations) + + self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId") diff --git a/src/cpl-database/cpl/database/scripts/mysql/0-cpl-initial.sql b/src/cpl-database/cpl/database/scripts/mysql/0-cpl-initial.sql new file mode 100644 index 00000000..d2a1b292 --- /dev/null +++ b/src/cpl-database/cpl/database/scripts/mysql/0-cpl-initial.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS _executed_migrations +( + migrationId VARCHAR(255) PRIMARY KEY, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/src/cpl-database/cpl/database/scripts/mysql/trigger.txt b/src/cpl-database/cpl/database/scripts/mysql/trigger.txt new file mode 100644 index 00000000..9ce1cf14 --- /dev/null +++ b/src/cpl-database/cpl/database/scripts/mysql/trigger.txt @@ -0,0 +1,26 @@ + +DELIMITER // +CREATE TRIGGER mytable_before_update + BEFORE UPDATE + ON mytable + FOR EACH ROW +BEGIN + INSERT INTO mytable_history + SELECT OLD.*; + + SET NEW.updated = NOW(); +END; +// +DELIMITER ; + +DELIMITER // +CREATE TRIGGER mytable_before_delete + BEFORE DELETE + ON mytable + FOR EACH ROW +BEGIN + INSERT INTO mytable_history + SELECT OLD.*; +END; +// +DELIMITER ; \ No newline at end of file diff --git a/src/cpl-database/cpl/database/scripts/postgres/0-cpl-initial.sql b/src/cpl-database/cpl/database/scripts/postgres/0-cpl-initial.sql new file mode 100644 index 00000000..3857a6e7 --- /dev/null +++ b/src/cpl-database/cpl/database/scripts/postgres/0-cpl-initial.sql @@ -0,0 +1,47 @@ +CREATE SCHEMA IF NOT EXISTS public; +CREATE SCHEMA IF NOT EXISTS system; + +CREATE TABLE IF NOT EXISTS system._executed_migrations +( + MigrationId VARCHAR(255) PRIMARY KEY, + Created timestamptz NOT NULL DEFAULT NOW(), + Updated timestamptz NOT NULL DEFAULT NOW() +); + +CREATE OR REPLACE FUNCTION public.history_trigger_function() + RETURNS TRIGGER AS +$$ +DECLARE + schema_name TEXT; + history_table_name TEXT; +BEGIN + -- Construct the name of the history table based on the current table + schema_name := TG_TABLE_SCHEMA; + history_table_name := TG_TABLE_NAME || '_history'; + + IF (TG_OP = 'INSERT') THEN + RETURN NEW; + END IF; + + -- Insert the old row into the history table on UPDATE or DELETE + IF (TG_OP = 'UPDATE' OR TG_OP = 'DELETE') THEN + EXECUTE format( + 'INSERT INTO %I.%I SELECT ($1).*', + schema_name, + history_table_name + ) + USING OLD; + END IF; + + -- For UPDATE, update the Updated column and return the new row + IF (TG_OP = 'UPDATE') THEN + NEW.updated := NOW(); -- Update the Updated column + RETURN NEW; + END IF; + + -- For DELETE, return OLD to allow the deletion + IF (TG_OP = 'DELETE') THEN + RETURN OLD; + END IF; +END; +$$ LANGUAGE plpgsql; diff --git a/src/cpl-database/cpl/database/service/__init__.py b/src/cpl-database/cpl/database/service/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-database/cpl/database/service/migration_service.py b/src/cpl-database/cpl/database/service/migration_service.py new file mode 100644 index 00000000..c84ee475 --- /dev/null +++ b/src/cpl-database/cpl/database/service/migration_service.py @@ -0,0 +1,111 @@ +import glob +import os + +from cpl.database.abc import DBContextABC +from cpl.database.db_logger import DBLogger +from cpl.database.model import Migration +from cpl.database.model.server_type import ServerType, ServerTypes +from cpl.database.schema.executed_migration import ExecutedMigration +from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao + +_logger = DBLogger(__name__) + + +class MigrationService: + + def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao): + self._db = db + self._executedMigrationDao = executedMigrationDao + + self._script_directories: list[str] = [] + + if ServerType.server_type == ServerTypes.POSTGRES: + self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/postgres")) + elif ServerType.server_type == ServerTypes.MYSQL: + self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql")) + + def with_directory(self, directory: str) -> "MigrationService": + self._script_directories.append(directory) + return self + + async def _get_migration_history(self) -> list[ExecutedMigration]: + results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}") + applied_migrations = [] + for result in results: + applied_migrations.append(ExecutedMigration(result[0])) + + return applied_migrations + + @staticmethod + def _load_scripts_by_path(path: str) -> list[Migration]: + migrations = [] + + if not os.path.exists(path): + raise Exception("Migration path not found") + + files = sorted(glob.glob(f"{path}/*")) + + for file in files: + if not file.endswith(".sql"): + continue + + name = str(file.split(".sql")[0]) + if "/" in name: + name = name.split("/")[-1] + + with open(f"{file}", "r") as f: + script = f.read() + f.close() + + migrations.append(Migration(name, script)) + + return migrations + + def _load_scripts(self) -> list[Migration]: + migrations = [] + for path in self._script_directories: + migrations.extend(self._load_scripts_by_path(path)) + + return migrations + + async def _get_tables(self): + if ServerType == ServerTypes.POSTGRES: + return await self._db.select( + """ + SELECT tablename + FROM pg_tables + WHERE schemaname = 'public'; + """ + ) + else: + return await self._db.select( + """ + SHOW TABLES; + """ + ) + + async def _execute(self, migrations: list[Migration]): + result = await self._get_tables() + + for migration in migrations: + active_statement = "" + try: + # check if table exists + if len(result) > 0: + migration_from_db = await self._executedMigrationDao.find_by_id(migration.name) + if migration_from_db is not None: + continue + + _logger.debug(f"Running upgrade migration: {migration.name}") + + await self._db.execute(migration.script, multi=True) + + await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True) + except Exception as e: + _logger.fatal( + f"Migration failed: {migration.name}\n{active_statement}", + e, + ) + + async def migrate(self): + await self._execute(self._load_scripts()) diff --git a/src/cpl-database/cpl/database/typing.py b/src/cpl-database/cpl/database/typing.py new file mode 100644 index 00000000..c3b7385a --- /dev/null +++ b/src/cpl-database/cpl/database/typing.py @@ -0,0 +1,65 @@ +from datetime import datetime +from typing import TypeVar, Union, Literal, Any + +from cpl.database.abc.db_model_abc import DbModelABC + + +T_DBM = TypeVar("T_DBM", bound=DbModelABC) + +NumberFilterOperator = Literal[ + "equal", + "notEqual", + "greater", + "greaterOrEqual", + "less", + "lessOrEqual", + "isNull", + "isNotNull", +] +StringFilterOperator = Literal[ + "equal", + "notEqual", + "contains", + "notContains", + "startsWith", + "endsWith", + "isNull", + "isNotNull", +] +BoolFilterOperator = Literal[ + "equal", + "notEqual", + "isNull", + "isNotNull", +] +DateFilterOperator = Literal[ + "equal", + "notEqual", + "greater", + "greaterOrEqual", + "less", + "lessOrEqual", + "isNull", + "isNotNull", +] +FilterOperator = Union[NumberFilterOperator, StringFilterOperator, BoolFilterOperator, DateFilterOperator] + +Attribute = Union[str, property] + +AttributeCondition = Union[ + dict[NumberFilterOperator, int], + dict[StringFilterOperator, str], + dict[BoolFilterOperator, bool], + dict[DateFilterOperator, datetime], +] +AttributeFilter = dict[Attribute, Union[list[Union[AttributeCondition, Any]], AttributeCondition, Any]] +AttributeFilters = Union[ + list[AttributeFilter], + AttributeFilter, +] + +AttributeSort = dict[Attribute, Literal["asc", "desc"]] +AttributeSorts = Union[ + list[AttributeSort], + AttributeSort, +] diff --git a/src/cpl-database/requirements.txt b/src/cpl-database/requirements.txt index e8d9db7b..e613d162 100644 --- a/src/cpl-database/requirements.txt +++ b/src/cpl-database/requirements.txt @@ -1,2 +1,8 @@ cpl-core -cpl-dependency \ No newline at end of file +cpl-dependency +psycopg[binary]==3.2.3 +psycopg-pool==3.2.4 +sqlparse==0.5.3 +mysql-connector-python==9.4.0 +async-property==0.2.2 +aiomysql==0.2.0 \ No newline at end of file diff --git a/src/cpl-dependency/cpl/dependency/service_provider_abc.py b/src/cpl-dependency/cpl/dependency/service_provider_abc.py index 142e9daa..d6e06b36 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider_abc.py +++ b/src/cpl-dependency/cpl/dependency/service_provider_abc.py @@ -20,6 +20,10 @@ class ServiceProviderABC(ABC): def set_global_provider(cls, provider: "ServiceProviderABC"): cls._provider = provider + @classmethod + def get_global_provider(cls) -> Optional["ServiceProviderABC"]: + return cls._provider + @abstractmethod def _build_by_signature(self, sig: Signature, origin_service_type: type) -> list[R]: pass diff --git a/tests/custom/database/src/application.py b/tests/custom/database/src/application.py index 6ba0edbe..0bbb49e9 100644 --- a/tests/custom/database/src/application.py +++ b/tests/custom/database/src/application.py @@ -5,6 +5,7 @@ from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.log import LoggerABC from cpl.dependency import ServiceProviderABC +from model.user_dao import UserDao from model.user_repo import UserRepo from model.user_repo_abc import UserRepoABC @@ -15,21 +16,28 @@ class Application(ApplicationABC): self._logger: Optional[LoggerABC] = None - def configure(self): - self._logger = self._services.get_service(LoggerABC) - - def main(self): - self._logger.debug(f"Host: {Environment.get_host_name()}") - self._logger.debug(f"Environment: {Environment.get_environment()}") - + async def test_repos(self): user_repo: UserRepo = self._services.get_service(UserRepoABC) - if len(user_repo.get_users()) == 0: + if len(await user_repo.get_users()) == 0: user_repo.add_test_user() Console.write_line("Users:") - for user in user_repo.get_users(): + for user in await user_repo.get_users(): Console.write_line(user.UserId, user.Name, user.City) Console.write_line("Cities:") - for city in user_repo.get_cities(): + for city in await user_repo.get_cities(): Console.write_line(city.CityId, city.Name, city.ZIP) + + async def test_daos(self): + userDao: UserDao = self._services.get_service(UserDao) + Console.write_line(await userDao.get_all()) + + async def configure(self): + self._logger = self._services.get_service(LoggerABC) + + async def main(self): + self._logger.debug(f"Host: {Environment.get_host_name()}") + self._logger.debug(f"Environment: {Environment.get_environment()}") + + await self.test_daos() diff --git a/tests/custom/database/src/appsettings.edrafts-lapi.json b/tests/custom/database/src/appsettings.edrafts-lapi.json index c78e3458..c5330111 100644 --- a/tests/custom/database/src/appsettings.edrafts-lapi.json +++ b/tests/custom/database/src/appsettings.edrafts-lapi.json @@ -15,8 +15,8 @@ "DatabaseSettings": { "AuthPlugin": "mysql_native_password", - "ConnectionString": "mysql+mysqlconnector://sh_cpl:$credentials@localhost/sh_cpl", - "Credentials": "MHZhc0Y2bjhKc1VUMWV0Qw==", + "ConnectionString": "mysql+mysqlconnector://cpl:$credentials@localhost/cpl", + "Credentials": "Y3Bs", "Encoding": "utf8mb4" } } \ No newline at end of file diff --git a/tests/custom/database/src/appsettings.edrafts-pc.json b/tests/custom/database/src/appsettings.edrafts-pc.json index 78bff4d4..66e6c101 100644 --- a/tests/custom/database/src/appsettings.edrafts-pc.json +++ b/tests/custom/database/src/appsettings.edrafts-pc.json @@ -15,9 +15,10 @@ "DatabaseSettings": { "Host": "localhost", - "User": "sh_cpl", - "Password": "MHZhc0Y2bjhKc1VUMWV0Qw==", - "Database": "sh_cpl", + "User": "cpl", + "Port": 3306, + "Password": "Y3Bs", + "Database": "cpl", "Charset": "utf8mb4", "UseUnicode": "true", "Buffered": "true" diff --git a/tests/custom/database/src/main.py b/tests/custom/database/src/main.py index a8909e63..86abcbc0 100644 --- a/tests/custom/database/src/main.py +++ b/tests/custom/database/src/main.py @@ -1,14 +1,16 @@ -from cpl.application import ApplicationBuilder - from application import Application +from cpl.application import ApplicationBuilder from startup import Startup -def main(): +async def main(): app_builder = ApplicationBuilder(Application) app_builder.use_startup(Startup) - app_builder.build().run() + app = await app_builder.build_async() + await app.run_async() if __name__ == "__main__": - main() + import asyncio + + asyncio.run(main()) diff --git a/tests/custom/database/src/model/city_model.py b/tests/custom/database/src/model/city_model.py index 0d22259d..f56bc8c7 100644 --- a/tests/custom/database/src/model/city_model.py +++ b/tests/custom/database/src/model/city_model.py @@ -1,4 +1,4 @@ -from cpl.database import TableABC +from cpl.database.abc.table_abc import TableABC class CityModel(TableABC): diff --git a/tests/custom/database/src/model/user.py b/tests/custom/database/src/model/user.py new file mode 100644 index 00000000..51b7ee18 --- /dev/null +++ b/tests/custom/database/src/model/user.py @@ -0,0 +1,16 @@ +from cpl.database.abc.db_model_abc import DbModelABC + + +class User(DbModelABC): + def __init__(self, id: int, name: str, city_id: int = 0): + DbModelABC.__init__(self, id) + self._name = name + self._city_id = city_id + + @property + def name(self) -> str: + return self._name + + @property + def city_id(self) -> int: + return self._city_id \ No newline at end of file diff --git a/tests/custom/database/src/model/user_dao.py b/tests/custom/database/src/model/user_dao.py new file mode 100644 index 00000000..e4a0a3ba --- /dev/null +++ b/tests/custom/database/src/model/user_dao.py @@ -0,0 +1,14 @@ +from cpl.database import InternalTables +from cpl.database.abc import DbModelDaoABC +from model.user import User + + +class UserDao(DbModelDaoABC[User]): + + def __init__(self): + DbModelDaoABC.__init__(self, __name__, User, InternalTables.users) + + self.attribute(User.name, str) + self.attribute(User.city_id, int, db_name="CityId") + + self.reference("city", "id", User.city_id, "city") diff --git a/tests/custom/database/src/model/user_model.py b/tests/custom/database/src/model/user_model.py index f6fa56d6..25b07e2d 100644 --- a/tests/custom/database/src/model/user_model.py +++ b/tests/custom/database/src/model/user_model.py @@ -1,5 +1,4 @@ -from cpl.database import TableABC - +from cpl.database.abc.table_abc import TableABC from .city_model import CityModel diff --git a/tests/custom/database/src/model/user_repo.py b/tests/custom/database/src/model/user_repo.py index c287cf5d..806c2209 100644 --- a/tests/custom/database/src/model/user_repo.py +++ b/tests/custom/database/src/model/user_repo.py @@ -1,41 +1,38 @@ -from cpl.core.console import Console -from cpl.database.mysql.context import DatabaseContextABC - +from cpl.database.abc.db_context_abc import DBContextABC from .city_model import CityModel from .user_model import UserModel from .user_repo_abc import UserRepoABC class UserRepo(UserRepoABC): - def __init__(self, db_context: DatabaseContextABC): + def __init__(self, db_context: DBContextABC): UserRepoABC.__init__(self) - self._db_context: DatabaseContextABC = db_context + self._db_context: DBContextABC = db_context def add_test_user(self): city = CityModel("Haren", "49733") city2 = CityModel("Meppen", "49716") - self._db_context.cursor.execute(city2.insert_string) + self._db_context.execute(city2.insert_string) user = UserModel("TestUser", city) - self._db_context.cursor.execute(user.insert_string) - self._db_context.save_changes() + self._db_context.execute(user.insert_string) - def get_users(self) -> list[UserModel]: + async def get_users(self) -> list[UserModel]: users = [] - results = self._db_context.select("SELECT * FROM `User`") + results = await self._db_context.select("SELECT * FROM `User`") for result in results: - users.append(UserModel(result[1], self.get_city_by_id(result[2]), id=result[0])) + users.append(UserModel(result[1], await self.get_city_by_id(result[2]), id=result[0])) return users - def get_cities(self) -> list[CityModel]: + async def get_cities(self) -> list[CityModel]: cities = [] - results = self._db_context.select("SELECT * FROM `City`") + results = await self._db_context.select("SELECT * FROM `City`") for result in results: cities.append(CityModel(result[1], result[2], id=result[0])) return cities - def get_city_by_id(self, id: int) -> CityModel: + async def get_city_by_id(self, id: int) -> CityModel: if id is None: return None - result = self._db_context.select(f"SELECT * FROM `City` WHERE `Id` = {id}") + result = await self._db_context.select(f"SELECT * FROM `City` WHERE `Id` = {id}") return CityModel(result[1], result[2], id=result[0]) diff --git a/tests/custom/database/src/scripts/0-initial.sql b/tests/custom/database/src/scripts/0-initial.sql new file mode 100644 index 00000000..7fa8584f --- /dev/null +++ b/tests/custom/database/src/scripts/0-initial.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS `city` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `name` VARCHAR(64) NOT NULL, + `zip` VARCHAR(5) NOT NULL, + PRIMARY KEY(`id`) +); + +CREATE TABLE IF NOT EXISTS `users` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `name` VARCHAR(64) NOT NULL, + `cityId` INT(30), + FOREIGN KEY (`cityId`) REFERENCES city(`id`), + PRIMARY KEY(`id`) +); \ No newline at end of file diff --git a/tests/custom/database/src/startup.py b/tests/custom/database/src/startup.py index 194734c9..c9ccfec5 100644 --- a/tests/custom/database/src/startup.py +++ b/tests/custom/database/src/startup.py @@ -1,28 +1,35 @@ -from cpl.application import StartupABC +from cpl.application.async_startup_abc import AsyncStartupABC from cpl.core.configuration import Configuration from cpl.core.environment import Environment from cpl.core.log import Logger, LoggerABC from cpl.database import mysql +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.database.service.migration_service import MigrationService from cpl.dependency import ServiceCollection +from model.user_dao import UserDao from model.user_repo import UserRepo from model.user_repo_abc import UserRepoABC -class Startup(StartupABC): +class Startup(AsyncStartupABC): def __init__(self): - StartupABC.__init__(self) + AsyncStartupABC.__init__(self) - self._configuration = None - - def configure_configuration(self, configuration: Configuration, environment: Environment): + async def configure_configuration(self, configuration: Configuration, environment: Environment): configuration.add_json_file(f"appsettings.json") configuration.add_json_file(f"appsettings.{environment.get_environment()}.json") configuration.add_json_file(f"appsettings.{environment.get_host_name()}.json", optional=True) - self._configuration = configuration - - def configure_services(self, services: ServiceCollection, environment: Environment): + async def configure_services(self, services: ServiceCollection, environment: Environment): services.add_module(mysql) + services.add_transient(DataAccessObjectABC, UserDao) + services.add_singleton(UserRepoABC, UserRepo) services.add_singleton(LoggerABC, Logger) + + provider = services.build_service_provider() + migration_service: MigrationService = provider.get_service(MigrationService) + + migration_service.with_directory("./scripts") + await migration_service.migrate()