Added dao base
All checks were successful
Build on push / prepare (push) Successful in 8s
Build on push / query (push) Successful in 16s
Build on push / core (push) Successful in 23s
Build on push / translation (push) Successful in 14s
Build on push / mail (push) Successful in 14s

This commit is contained in:
2025-09-16 22:19:59 +02:00
parent 58dbd3ed1e
commit 4625b626e6
54 changed files with 2199 additions and 340 deletions

View File

@@ -1,23 +1,41 @@
from typing import Type
from cpl.dependency import ServiceCollection as _ServiceCollection
from . import mysql
from .database_settings import DatabaseSettings
from .database_settings_name_enum import DatabaseSettingsNameEnum
from .mysql.context import DatabaseContextABC, DatabaseContext
from .table_abc import TableABC
from . import mysql as _mysql
from . import postgres as _postgres
from .internal_tables import InternalTables
def add_mysql(collection: _ServiceCollection):
def _add(collection: _ServiceCollection,db_context: Type, default_port: int, server_type: str):
from cpl.core.console import Console
from cpl.core.configuration import Configuration
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.model.server_type import ServerTypes, ServerType
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.service.migration_service import MigrationService
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
try:
collection.add_singleton(DatabaseContextABC, DatabaseContext)
database_context = collection.build_service_provider().get_service(DatabaseContextABC)
ServerType.set_server_type(ServerTypes(server_type))
Configuration.set("DB_DEFAULT_PORT", default_port)
db_settings: DatabaseSettings = Configuration.get(DatabaseSettings)
database_context.connect(db_settings)
collection.add_singleton(DBContextABC, db_context)
collection.add_singleton(ExecutedMigrationDao)
collection.add_singleton(MigrationService)
except ImportError as e:
Console.error("cpl-translation is not installed", str(e))
Console.error("cpl-database is not installed", str(e))
def add_mysql(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 3306, ServerTypes.MYSQL.value)
_ServiceCollection.with_module(add_mysql, mysql.__name__)
def add_postgres(collection: _ServiceCollection):
from cpl.database.mysql.db_context import DBContext
from cpl.database.model import ServerTypes
_add(collection, DBContext, 5432, ServerTypes.POSTGRES.value)
_ServiceCollection.with_module(add_mysql, _mysql.__name__)
_ServiceCollection.with_module(add_postgres, _postgres.__name__)

View File

@@ -0,0 +1,68 @@
import textwrap
from typing import Callable
class ExternalDataTempTableBuilder:
def __init__(self):
self._table_name = None
self._fields: dict[str, str] = {}
self._primary_key = "id"
self._join_ref_table = None
self._value_getter = None
@property
def table_name(self) -> str:
return self._table_name
@property
def fields(self) -> dict[str, str]:
return self._fields
@property
def primary_key(self) -> str:
return self._primary_key
@property
def join_ref_table(self) -> str:
return self._join_ref_table
def with_table_name(self, table_name: str) -> "ExternalDataTempTableBuilder":
self._join_ref_table = table_name
if "." in table_name:
table_name = table_name.split(".")[-1]
if not table_name.endswith("_temp"):
table_name = f"{table_name}_temp"
self._table_name = table_name
return self
def with_field(self, name: str, sql_type: str, primary=False) -> "ExternalDataTempTableBuilder":
if primary:
sql_type += " PRIMARY KEY"
self._primary_key = name
self._fields[name] = sql_type
return self
def with_value_getter(self, value_getter: Callable) -> "ExternalDataTempTableBuilder":
self._value_getter = value_getter
return self
async def build(self) -> str:
assert self._table_name is not None, "Table name is required"
assert self._value_getter is not None, "Value getter is required"
values_str = ", ".join([f"{value}" for value in await self._value_getter()])
return textwrap.dedent(
f"""
DROP TABLE IF EXISTS {self._table_name};
CREATE TEMP TABLE {self._table_name} (
{", ".join([f"{k} {v}" for k, v in self._fields.items()])}
);
INSERT INTO {self._table_name} VALUES {values_str};
"""
)

View File

@@ -0,0 +1,5 @@
from .connection_abc import ConnectionABC
from .db_context_abc import DBContextABC
from .db_join_model_abc import DbJoinModelABC
from .db_model_abc import DbModelABC
from .db_model_dao_abc import DbModelDaoABC

View File

@@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
from cpl.database.database_settings import DatabaseSettings
from cpl.database.model.database_settings import DatabaseSettings
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.connector.cursor import MySQLCursorBuffered
class DatabaseConnectionABC(ABC):
class ConnectionABC(ABC):
r"""ABC for the :class:`cpl.database.connection.database_connection.DatabaseConnection`"""
@abstractmethod

View File

@@ -0,0 +1,875 @@
import datetime
from abc import ABC, abstractmethod
from enum import Enum
from types import NoneType
from typing import Generic, Optional, Union, Type, List, Any
from cpl.core.typing import T, Id
from cpl.core.utils import String
from cpl.core.utils.get_value import get_value
from cpl.database._external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT
from cpl.database.db_logger import DBLogger
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
from cpl.dependency.service_provider_abc import ServiceProviderABC
self._db = ServiceProviderABC.get_global_provider().get_service(DBContextABC)
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._default_filter_condition = None
self.__attributes: dict[str, type] = {}
self.__db_names: dict[str, str] = {}
self.__foreign_tables: dict[str, tuple[str, str]] = {}
self.__foreign_table_keys: dict[str, str] = {}
self.__foreign_dao: dict[str, "DataAccessObjectABC"] = {}
self.__date_attributes: set[str] = set()
self.__ignored_attributes: set[str] = set()
self.__primary_key = "id"
self.__primary_key_type = int
self._external_fields: dict[str, ExternalDataTempTableBuilder] = {}
@property
def table_name(self) -> str:
return self._table_name
def has_attribute(self, attr_name: Attribute) -> bool:
"""
Check if the attribute exists in the DAO
:param Attribute attr_name: Name of the attribute
:return: True if the attribute exists, False otherwise
"""
return attr_name in self.__attributes
def attribute(
self,
attr_name: Attribute,
attr_type: type,
db_name: str = None,
ignore=False,
primary_key=False,
aliases: list[str] = None,
):
"""
Add an attribute for db and object mapping to the data access object
:param Attribute attr_name: Name of the attribute in the object
:param type attr_type: Python type of the attribute to cast db value to
:param str db_name: Name of the field in the database, if None the attribute lowered attr_name without "_" is used
:param bool ignore: Defines if field is ignored for create and update (for e.g. auto increment fields or created/updated fields)
:param bool primary_key: Defines if field is the primary key
:param list[str] aliases: List of aliases for the attribute name
:return:
"""
if isinstance(attr_name, property):
attr_name = attr_name.fget.__name__
self.__attributes[attr_name] = attr_type
if ignore:
self.__ignored_attributes.add(attr_name)
if not db_name:
db_name = attr_name.lower().replace("_", "")
self.__db_names[attr_name] = db_name
self.__db_names[db_name] = db_name
if aliases is not None:
for alias in aliases:
if alias in self.__db_names:
raise ValueError(f"Alias {alias} already exists")
self.__db_names[alias] = db_name
if primary_key:
self.__primary_key = db_name
self.__primary_key_type = attr_type
if attr_type in [datetime, datetime.datetime]:
self.__date_attributes.add(attr_name)
self.__date_attributes.add(db_name)
def reference(
self,
attr: Attribute,
primary_attr: Attribute,
foreign_attr: Attribute,
table_name: str,
reference_dao: "DataAccessObjectABC" = None,
):
"""
Add a reference to another table for the given attribute
:param Attribute attr: Name of the attribute in the object
:param str primary_attr: Name of the primary key in the foreign object
:param str foreign_attr: Name of the foreign key in the object
:param str table_name: Name of the table to reference
:param DataAccessObjectABC reference_dao: The data access object for the referenced table
:return:
"""
if isinstance(attr, property):
attr = attr.fget.__name__
if isinstance(primary_attr, property):
primary_attr = primary_attr.fget.__name__
primary_attr = primary_attr.lower().replace("_", "")
if isinstance(foreign_attr, property):
foreign_attr = foreign_attr.fget.__name__
foreign_attr = foreign_attr.lower().replace("_", "")
self.__foreign_table_keys[attr] = foreign_attr
if reference_dao is not None:
self.__foreign_dao[attr] = reference_dao
if table_name == self._table_name:
return
self.__foreign_tables[attr] = (
table_name,
f"{table_name}.{primary_attr} = {self._table_name}.{foreign_attr}",
)
def use_external_fields(self, builder: ExternalDataTempTableBuilder):
self._external_fields[builder.table_name] = builder
def to_object(self, result: dict) -> T_DBM:
"""
Convert a result from the database to an object
:param dict result: Result from the database
:return:
"""
value_map: dict[str, T] = {}
for db_name, value in result.items():
# Find the attribute name corresponding to the db_name
attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
if attr_name:
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
return self._model_type(**value_map)
def to_dict(self, obj: T_DBM) -> dict:
"""
Convert an object to a dictionary
:param T_DBM obj: Object to convert
:return:
"""
value_map: dict[str, Any] = {}
for attr_name, attr_type in self.__attributes.items():
value = getattr(obj, attr_name)
if isinstance(value, datetime.datetime):
value = value.strftime(DATETIME_FORMAT)
elif isinstance(value, Enum):
value = value.value
value_map[attr_name] = value
for ex_fname in self._external_fields:
ex_field = self._external_fields[ex_fname]
for ex_attr in ex_field.fields:
if ex_attr == self.__primary_key:
continue
value_map[ex_attr] = getattr(obj, ex_attr, None)
return value_map
async def count(self, filters: AttributeFilters = None) -> int:
result = await self._prepare_query(filters=filters, for_count=True)
return result[0]["count"] if result else 0
async def get_history(
self,
entry_id: int,
by_key: str = None,
when: datetime = None,
until: datetime = None,
without_deleted: bool = False,
) -> list[T_DBM]:
"""
Retrieve the history of an entry from the history table.
:param entry_id: The ID of the entry to retrieve history for.
:param by_key: The key to filter by (default is the primary key).
:param when: A specific timestamp to filter the history.
:param until: A timestamp to filter history entries up to a certain point.
:param without_deleted: Exclude deleted entries if True.
:return: A list of historical entries as objects.
"""
f_tables = list(self.__foreign_tables.keys())
history_table = f"{self._table_name}_history"
builder = SQLSelectBuilder(history_table, self.__primary_key)
builder.with_attribute("*")
builder.with_value_condition(
f"{history_table}.{by_key or self.__primary_key}",
"=",
str(entry_id),
f_tables,
)
if self._default_filter_condition:
builder.with_condition(self._default_filter_condition, "", f_tables)
if without_deleted:
builder.with_value_condition(f"{history_table}.deleted", "=", "false", f_tables)
if when:
builder.with_value_condition(
self._attr_from_date_to_char(f"{history_table}.updated"),
"=",
f"'{when.strftime(DATETIME_FORMAT)}'",
f_tables,
)
if until:
builder.with_value_condition(
self._attr_from_date_to_char(f"{history_table}.updated"),
"<=",
f"'{until.strftime(DATETIME_FORMAT)}'",
f_tables,
)
builder.with_order_by(f"{history_table}.updated", "DESC")
query = await builder.build()
result = await self._db.select_map(query)
return [self.to_object(x) for x in result] if result else []
async def get_all(self) -> List[T_DBM]:
result = await self._prepare_query(sorts=[{self.__primary_key: "asc"}])
return [self.to_object(x) for x in result] if result else []
async def get_by_id(self, id: Union[int, str]) -> Optional[T_DBM]:
result = await self._prepare_query(filters=[{self.__primary_key: id}], sorts=[{self.__primary_key: "asc"}])
return self.to_object(result[0]) if result else None
async def find_by_id(self, id: Union[int, str]) -> Optional[T_DBM]:
result = await self._prepare_query(filters=[{self.__primary_key: id}], sorts=[{self.__primary_key: "asc"}])
return self.to_object(result[0]) if result else None
async def get_by(
self,
filters: AttributeFilters = None,
sorts: AttributeSorts = None,
take: int = None,
skip: int = None,
) -> list[T_DBM]:
result = await self._prepare_query(filters, sorts, take, skip)
if not result or len(result) == 0:
raise ValueError("No result found")
return [self.to_object(x) for x in result] if result else []
async def get_single_by(
self,
filters: AttributeFilters = None,
sorts: AttributeSorts = None,
take: int = None,
skip: int = None,
) -> T_DBM:
result = await self._prepare_query(filters, sorts, take, skip)
if not result:
raise ValueError("No result found")
if len(result) > 1:
raise ValueError("More than one result found")
return self.to_object(result[0])
async def find_by(
self,
filters: AttributeFilters = None,
sorts: AttributeSorts = None,
take: int = None,
skip: int = None,
) -> list[T_DBM]:
result = await self._prepare_query(filters, sorts, take, skip)
return [self.to_object(x) for x in result] if result else []
async def find_single_by(
self,
filters: AttributeFilters = None,
sorts: AttributeSorts = None,
take: int = None,
skip: int = None,
) -> Optional[T_DBM]:
result = await self._prepare_query(filters, sorts, take, skip)
if len(result) > 1:
raise ValueError("More than one result found")
return self.to_object(result[0]) if result else None
async def touch(self, obj: T_DBM):
"""
Touch the entry to update the last updated date
:return:
"""
await self._db.execute(
f"""
UPDATE {self._table_name}
SET updated = NOW()
WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)};
"""
)
async def touch_many_by_id(self, ids: list[Id]):
"""
Touch the entries to update the last updated date
:return:
"""
if len(ids) == 0:
return
await self._db.execute(
f"""
UPDATE {self._table_name}
SET updated = NOW()
WHERE {self.__primary_key} IN ({", ".join([str(x) for x in ids])});
"""
)
async def _build_create_statement(self, obj: T_DBM, skip_editor=False) -> str:
allowed_fields = [x for x in self.__attributes.keys() if x not in self.__ignored_attributes]
fields = ", ".join([self.__db_names[x] for x in allowed_fields])
fields = f"{'EditorId' if not skip_editor else ''}{f', {fields}' if not skip_editor and len(fields) > 0 else f'{fields}'}"
values = ", ".join([self._get_value_sql(getattr(obj, x)) for x in allowed_fields])
values = f"{await self._get_editor_id(obj) if not skip_editor else ''}{f', {values}' if not skip_editor and len(values) > 0 else f'{values}'}"
return f"""
INSERT INTO {self._table_name} (
{fields}
) VALUES (
{values}
)
RETURNING {self.__primary_key};
"""
async def create(self, obj: T_DBM, skip_editor=False) -> int:
self._logger.debug(f"create {type(obj).__name__} {obj.__dict__}")
result = await self._db.execute(await self._build_create_statement(obj, skip_editor))
return self._get_value_from_sql(self.__primary_key_type, result[0][0])
async def create_many(self, objs: list[T_DBM], skip_editor=False) -> list[int]:
if len(objs) == 0:
return []
self._logger.debug(f"create many {type(objs[0]).__name__} {len(objs)} {[x.__dict__ for x in objs]}")
query = ""
for obj in objs:
query += await self._build_create_statement(obj, skip_editor)
result = await self._db.execute(query)
return [self._get_value_from_sql(self.__primary_key_type, x[0]) for x in result]
async def _build_update_statement(self, obj: T_DBM, skip_editor=False) -> str:
allowed_fields = [x for x in self.__attributes.keys() if x not in self.__ignored_attributes]
fields = ", ".join(
[f"{self.__db_names[x]} = {self._get_value_sql(getattr(obj, x, None))}" for x in allowed_fields]
)
fields = f"{f'EditorId = {await self._get_editor_id(obj)}' if not skip_editor else ''}{f', {fields}' if not skip_editor and len(fields) > 0 else f'{fields}'}"
return f"""
UPDATE {self._table_name}
SET {fields}
WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)};
"""
async def update(self, obj: T_DBM, skip_editor=False):
self._logger.debug(f"update {type(obj).__name__} {obj.__dict__}")
await self._db.execute(await self._build_update_statement(obj, skip_editor))
async def update_many(self, objs: list[T_DBM], skip_editor=False):
if len(objs) == 0:
return
self._logger.debug(f"update many {type(objs[0]).__name__} {len(objs)} {[x.__dict__ for x in objs]}")
query = ""
for obj in objs:
query += await self._build_update_statement(obj, skip_editor)
await self._db.execute(query)
async def _build_delete_statement(self, obj: T_DBM, hard_delete: bool = False) -> str:
if hard_delete:
return f"""
DELETE FROM {self._table_name}
WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)};
"""
return f"""
UPDATE {self._table_name}
SET EditorId = {await self._get_editor_id(obj)},
Deleted = true
WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)};
"""
async def delete(self, obj: T_DBM, hard_delete: bool = False):
self._logger.debug(f"delete {type(obj).__name__} {obj.__dict__}")
await self._db.execute(await self._build_delete_statement(obj, hard_delete))
async def delete_many(self, objs: list[T_DBM], hard_delete: bool = False):
if len(objs) == 0:
return
self._logger.debug(f"delete many {type(objs[0]).__name__} {len(objs)} {[x.__dict__ for x in objs]}")
query = ""
for obj in objs:
query += await self._build_delete_statement(obj, hard_delete)
await self._db.execute(query)
async def _build_restore_statement(self, obj: T_DBM) -> str:
return f"""
UPDATE {self._table_name}
SET EditorId = {await self._get_editor_id(obj)},
Deleted = false
WHERE {self.__primary_key} = {self._get_primary_key_value_sql(obj)};
"""
async def restore(self, obj: T_DBM):
self._logger.debug(f"restore {type(obj).__name__} {obj.__dict__}")
await self._db.execute(await self._build_restore_statement(obj))
async def restore_many(self, objs: list[T_DBM]):
if len(objs) == 0:
return
self._logger.debug(f"restore many {type(objs[0]).__name__} {len(objs)} {objs[0].__dict__}")
query = ""
for obj in objs:
query += await self._build_restore_statement(obj)
await self._db.execute(query)
async def _prepare_query(
self,
filters: AttributeFilters = None,
sorts: AttributeSorts = None,
take: int = None,
skip: int = None,
for_count=False,
) -> list[dict]:
"""
Prepares and executes a query using the SQLBuilder with the given parameters.
:param filters: Conditions to filter the query.
:param sorts: Sorting attributes and directions.
:param take: Limit the number of results.
:param skip: Offset the results.
:return: Query result as a list of dictionaries.
"""
external_table_deps = []
builder = SQLSelectBuilder(self._table_name, self.__primary_key)
for temp in self._external_fields:
builder.with_temp_table(self._external_fields[temp])
if for_count:
builder.with_attribute("COUNT(*)", ignore_table_name=True)
else:
builder.with_attribute("*")
for attr in self.__foreign_tables:
table, join_condition = self.__foreign_tables[attr]
builder.with_left_join(table, join_condition)
if filters:
await self._build_conditions(builder, filters, external_table_deps)
if sorts:
self._build_sorts(builder, sorts, external_table_deps)
if take:
builder.with_limit(take)
if skip:
builder.with_offset(skip)
for external_table in external_table_deps:
builder.use_temp_table(external_table)
query = await builder.build()
return await self._db.select_map(query)
async def _build_conditions(
self,
builder: SQLSelectBuilder,
filters: AttributeFilters,
external_table_deps: list[str],
):
"""
Builds SQL conditions from GraphQL-like filters and adds them to the SQLBuilder.
:param builder: The SQLBuilder instance to add conditions to.
:param filters: GraphQL-like filter structure.
:param external_table_deps: List to store external table dependencies.
"""
if not isinstance(filters, list):
filters = [filters]
for filter_group in filters:
sql_conditions = self._graphql_to_sql_conditions(filter_group, external_table_deps)
for attr, operator, value in sql_conditions:
if attr in self.__foreign_table_keys:
attr = self.__foreign_table_keys[attr]
recursive_join = self._get_recursive_reference_join(attr)
if recursive_join is not None:
builder.with_left_join(*recursive_join)
external_table = self._get_external_field_key(attr)
if external_table is not None:
external_table_deps.append(external_table)
if operator == "fuzzy":
builder.with_levenshtein_condition(attr)
elif operator in [
"IS NULL",
"IS NOT NULL",
]: # operator without value
builder.with_condition(
attr,
operator,
[
x[0]
for fdao in self.__foreign_dao
for x in self.__foreign_dao[fdao].__foreign_tables.values()
],
)
else:
if attr in self.__date_attributes or String.to_snake_case(attr) in self.__date_attributes:
attr = self._attr_from_date_to_char(f"{self._table_name}.{attr}")
builder.with_value_condition(
attr,
operator,
self._get_value_sql(value),
[
x[0]
for fdao in self.__foreign_dao
for x in self.__foreign_dao[fdao].__foreign_tables.values()
],
)
def _graphql_to_sql_conditions(
self, graphql_structure: dict, external_table_deps: list[str]
) -> list[tuple[str, str, Any]]:
"""
Converts a GraphQL-like structure to SQL conditions.
:param graphql_structure: The GraphQL-like filter structure.
:param external_table_deps: List to track external table dependencies.
:return: A list of tuples (attribute, operator, value).
"""
operators = {
"equal": "=",
"notEqual": "!=",
"greater": ">",
"greaterOrEqual": ">=",
"less": "<",
"lessOrEqual": "<=",
"isNull": "IS NULL",
"isNotNull": "IS NOT NULL",
"contains": "LIKE", # Special handling in _graphql_to_sql_conditions
"notContains": "NOT LIKE", # Special handling in _graphql_to_sql_conditions
"startsWith": "LIKE", # Special handling in _graphql_to_sql_conditions
"endsWith": "LIKE", # Special handling in _graphql_to_sql_conditions
"in": "IN",
"notIn": "NOT IN",
}
conditions = []
def parse_node(node, parent_key=None, parent_dao=None):
if not isinstance(node, dict):
return
if isinstance(node, list):
conditions.append((parent_key, "IN", node))
return
for key, value in node.items():
if isinstance(key, property):
key = key.fget.__name__
external_fields_table_name_by_parent = self._get_external_field_key(parent_key)
external_fields_table_name = self._get_external_field_key(key)
external_field = (
external_fields_table_name
if external_fields_table_name_by_parent is None
else external_fields_table_name_by_parent
)
if key == "fuzzy":
self._handle_fuzzy_filter_conditions(conditions, external_table_deps, value)
elif parent_dao is not None and key in parent_dao.__db_names:
parse_node(value, f"{parent_dao.table_name}.{key}")
continue
elif external_field is not None:
external_table_deps.append(external_field)
parse_node(value, f"{external_field}.{key}")
elif parent_key in self.__foreign_table_keys:
if key in operators:
parse_node({key: value}, self.__foreign_table_keys[parent_key])
continue
if parent_key in self.__foreign_dao:
foreign_dao = self.__foreign_dao[parent_key]
if key in foreign_dao.__foreign_tables:
parse_node(
value,
f"{self.__foreign_tables[parent_key][0]}.{foreign_dao.__foreign_table_keys[key]}",
foreign_dao.__foreign_dao[key],
)
continue
if parent_key in self.__foreign_tables:
parse_node(value, f"{self.__foreign_tables[parent_key][0]}.{key}")
continue
parse_node({parent_key: value})
elif key in operators:
operator = operators[key]
if key == "contains" or key == "notContains":
value = f"%{value}%"
elif key == "in" or key == "notIn":
value = value
elif key == "startsWith":
value = f"{value}%"
elif key == "endsWith":
value = f"%{value}"
elif key == "isNull" or key == "isNotNull":
is_null_value = value.get("equal", None) if isinstance(value, dict) else value
if is_null_value is None:
operator = operators[key]
elif (key == "isNull" and is_null_value) or (key == "isNotNull" and not is_null_value):
operator = "IS NULL"
else:
operator = "IS NOT NULL"
conditions.append((parent_key, operator, None))
elif (key == "equal" or key == "notEqual") and value is None:
operator = operators["isNull"]
conditions.append((parent_key, operator, value))
elif isinstance(value, dict):
if key in self.__foreign_table_keys:
parse_node(value, key)
elif key in self.__db_names and parent_key is not None:
parse_node({f"{parent_key}": value})
elif key in self.__db_names:
parse_node(value, self.__db_names[key])
else:
parse_node(value, key)
elif value is None:
conditions.append((self.__db_names[key], "IS NULL", value))
else:
conditions.append((self.__db_names[key], "=", value))
parse_node(graphql_structure)
return conditions
def _handle_fuzzy_filter_conditions(self, conditions, external_field_table_deps, sub_values):
# Extract fuzzy filter parameters
fuzzy_fields = get_value(sub_values, "fields", list[str])
fuzzy_term = get_value(sub_values, "term", str)
fuzzy_threshold = get_value(sub_values, "threshold", int, 5)
if not fuzzy_fields or not fuzzy_term:
raise ValueError("Fuzzy filter must include 'fields' and 'term'.")
fuzzy_fields_db_names = []
# Map fields to their database names
for fuzzy_field in fuzzy_fields:
external_fields_table_name = self._get_external_field_key(fuzzy_field)
if external_fields_table_name is not None:
external_fields_table = self._external_fields[external_fields_table_name]
fuzzy_fields_db_names.append(f"{external_fields_table.table_name}.{fuzzy_field}")
external_field_table_deps.append(external_fields_table.table_name)
elif fuzzy_field in self.__db_names:
fuzzy_fields_db_names.append(f"{self._table_name}.{self.__db_names[fuzzy_field]}")
elif fuzzy_field in self.__foreign_tables:
fuzzy_fields_db_names.append(f"{self._table_name}.{self.__foreign_table_keys[fuzzy_field]}")
else:
fuzzy_fields_db_names.append(self.__db_names[String.to_snake_case(fuzzy_field)][0])
# Build fuzzy conditions for each field
fuzzy_conditions = self._build_fuzzy_conditions(fuzzy_fields_db_names, fuzzy_term, fuzzy_threshold)
# Combine conditions with OR and append to the main conditions
conditions.append((f"({' OR '.join(fuzzy_conditions)})", "fuzzy", None))
@staticmethod
def _build_fuzzy_conditions(fields: list[str], term: str, threshold: int = 10) -> list[str]:
conditions = []
for field in fields:
conditions.append(f"levenshtein({field}::TEXT, '{term}') <= {threshold}") # Adjust the threshold as needed
return conditions
def _get_external_field_key(self, field_name: str) -> Optional[str]:
"""
Returns the key to get the external field if found, otherwise None.
:param str field_name: The name of the field to search for.
:return: The key if found, otherwise None.
:rtype: Optional[str]
"""
if field_name is None:
return None
for key, builder in self._external_fields.items():
if field_name in builder.fields and field_name not in self.__db_names:
return key
return None
def _get_recursive_reference_join(self, attr: str) -> Optional[tuple[str, str]]:
parts = attr.split(".")
table_name = ".".join(parts[:-1])
if table_name == self._table_name or table_name == "":
return None
all_foreign_tables = {
x[0]: x[1]
for x in [
*[x for x in self.__foreign_tables.values() if x[0] != self._table_name],
*[x for fdao in self.__foreign_dao for x in self.__foreign_dao[fdao].__foreign_tables.values()],
]
}
if not table_name in all_foreign_tables:
return None
return table_name, all_foreign_tables[table_name]
def _build_sorts(
self,
builder: SQLSelectBuilder,
sorts: AttributeSorts,
external_table_deps: list[str],
):
"""
Resolves complex sorting structures into SQL-compatible sorting conditions.
Tracks external table dependencies.
:param builder: The SQLBuilder instance to add sorting to.
:param sorts: Sorting attributes and directions in a complex structure.
:param external_table_deps: List to track external table dependencies.
"""
def parse_sort_node(node, parent_key=None):
if isinstance(node, dict):
for key, value in node.items():
if isinstance(value, dict):
# Recursively parse nested structures
parse_sort_node(value, key)
elif isinstance(value, str) and value.lower() in ["asc", "desc"]:
external_table = self._get_external_field_key(key)
if external_table:
external_table_deps.append(external_table)
key = f"{external_table}.{key}"
if parent_key in self.__foreign_tables:
key = f"{self.__foreign_tables[parent_key][0]}.{key}"
builder.with_order_by(key, value.upper())
else:
raise ValueError(f"Invalid sort direction: {value}")
elif isinstance(node, list):
for item in node:
parse_sort_node(item)
else:
raise ValueError(f"Invalid sort structure: {node}")
parse_sort_node(sorts)
def _get_value_sql(self, value: Any) -> str:
if isinstance(value, str):
if value.lower() == "null":
return "NULL"
return f"'{value}'"
if isinstance(value, NoneType):
return "NULL"
if value is None:
return "NULL"
if isinstance(value, Enum):
return f"'{value.value}'"
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, list):
if len(value) == 0:
return "()"
return f"({', '.join([self._get_value_sql(x) for x in value])})"
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
value = value.replace(tzinfo=datetime.timezone.utc)
return f"'{value.strftime(DATETIME_FORMAT)}'"
return str(value)
@staticmethod
def _get_value_from_sql(cast_type: type, value: Any) -> Optional[T]:
"""
Get the value from the query result and cast it to the correct type
:param type cast_type:
:param Any value:
:return Optional[T]: Casted value, when value is str "NULL" None is returned
"""
if isinstance(value, str) and "NULL" in value:
return None
if isinstance(value, NoneType):
return None
if isinstance(value, cast_type):
return value
return cast_type(value)
def _get_primary_key_value_sql(self, obj: T_DBM) -> str:
value = getattr(obj, self.__primary_key)
if isinstance(value, str):
return f"'{value}'"
return value
@staticmethod
def _attr_from_date_to_char(attr: str) -> str:
return f"TO_CHAR({attr}, 'YYYY-MM-DD HH24:MI:SS.US TZ')"
@staticmethod
async def _get_editor_id(obj: T_DBM):
editor_id = obj.editor_id
# if editor_id is None:
# user = get_user()
# if user is not None:
# editor_id = user.id
return editor_id if editor_id is not None else "NULL"

View File

@@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Any
from cpl.database.model.database_settings import DatabaseSettings
class DBContextABC(ABC):
r"""ABC for the :class:`cpl.database.context.database_context.DatabaseContext`"""
@abstractmethod
def connect(self, database_settings: DatabaseSettings):
r"""Connects to a database by connection settings
Parameter:
database_settings :class:`cpl.database.database_settings.DatabaseSettings`
"""
@abstractmethod
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
r"""Runs SQL Statements
Parameter:
statement: :class:`str`
args: :class:`list` | :class:`tuple` | :class:`dict` | :class:`None`
multi: :class:`bool`
Returns:
list: Fetched list of executed elements
"""
@abstractmethod
async def select_map(self, statement: str, args=None) -> list[dict]:
r"""Runs SQL Select Statements and returns a list of dictionaries
Parameter:
statement: :class:`str`
args: :class:`list` | :class:`tuple` | :class:`dict` | :class:`None`
Returns:
list: Fetched list of executed elements as dictionary
"""
@abstractmethod
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
r"""Runs SQL Select Statements and returns a list of dictionaries
Parameter:
statement: :class:`str`
args: :class:`list` | :class:`tuple` | :class:`dict` | :class:`None`
Returns:
list: Fetched list of executed elements
"""

View File

@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Optional
from cpl.core.typing import Id, SerialId
from cpl.database.abc.db_model_abc import DbModelABC
class DbJoinModelABC[T](DbModelABC[T]):
def __init__(
self,
id: Id,
source_id: Id,
foreign_id: Id,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._source_id = source_id
self._foreign_id = foreign_id
@property
def source_id(self) -> Id:
return self._source_id
@property
def foreign_id(self) -> Id:
return self._foreign_id

View File

@@ -0,0 +1,79 @@
from abc import ABC
from datetime import datetime, timezone
from typing import Optional, Generic
from cpl.core.typing import Id, SerialId, T
class DbModelABC(ABC, Generic[T]):
def __init__(
self,
id: Id,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
self._id = id
self._deleted = deleted
self._editor_id = editor_id
self._created = created if created is not None else datetime.now(timezone.utc).isoformat()
self._updated = updated if updated is not None else datetime.now(timezone.utc).isoformat()
@property
def id(self) -> Id:
return self._id
@property
def deleted(self) -> bool:
return self._deleted
@deleted.setter
def deleted(self, value: bool):
self._deleted = value
@property
def editor_id(self) -> SerialId:
return self._editor_id
@editor_id.setter
def editor_id(self, value: SerialId):
self._editor_id = value
# @async_property
# async def editor(self):
# if self._editor_id is None:
# return None
#
# from data.schemas.administration.user_dao import userDao
#
# return await userDao.get_by_id(self._editor_id)
@property
def created(self) -> datetime:
return self._created
@property
def updated(self) -> datetime:
return self._updated
@updated.setter
def updated(self, value: datetime):
self._updated = value
def to_dict(self) -> dict:
result = {}
for name, value in self.__dict__.items():
if not name.startswith("_") or name.endswith("_"):
continue
if isinstance(value, datetime):
value = value.isoformat()
if not isinstance(value, str):
value = str(value)
result[name.replace("_", "")] = value
return result

View File

@@ -0,0 +1,25 @@
from abc import abstractmethod
from datetime import datetime
from typing import Type
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.abc.db_model_abc import DbModelABC
from cpl.database.internal_tables import InternalTables
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
@abstractmethod
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, source, model_type, table_name)
self.attribute(DbModelABC.id, int, ignore=True)
self.attribute(DbModelABC.deleted, bool)
self.attribute(DbModelABC.editor_id, int, ignore=True) # handled by db trigger
self.reference(
"editor", "id", DbModelABC.editor_id, InternalTables.users
) # not relevant for updates due to editor_id
self.attribute(DbModelABC.created, datetime, ignore=True) # handled by db trigger
self.attribute(DbModelABC.updated, datetime, ignore=True) # handled by db trigger

View File

@@ -0,0 +1 @@
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z"

View File

@@ -1,13 +0,0 @@
from enum import Enum
class DatabaseSettingsNameEnum(Enum):
host = "Host"
port = "Port"
user = "User"
password = "Password"
database = "Database"
charset = "Charset"
use_unicode = "UseUnicode"
buffered = "Buffered"
auth_plugin = "AuthPlugin"

View File

@@ -0,0 +1,15 @@
from cpl.database.model.server_type import ServerTypes, ServerType
class InternalTables:
@classmethod
@property
def users(cls) -> str:
return "administration.users" if ServerType.server_type is ServerTypes.POSTGRES else "users"
@classmethod
@property
def executed_migrations(cls) -> str:
return "system._executed_migrations" if ServerType.server_type is ServerTypes.POSTGRES else "_executed_migrations"

View File

@@ -0,0 +1,3 @@
from .database_settings import DatabaseSettings
from .migration import Migration
from .server_type import ServerTypes

View File

@@ -1,6 +1,9 @@
from typing import Optional
from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
from cpl.core.environment import Environment
from cpl.core.utils import Base64
class DatabaseSettings(ConfigurationModelABC):
@@ -8,23 +11,23 @@ class DatabaseSettings(ConfigurationModelABC):
def __init__(
self,
host: str = None,
port: int = 3306,
user: str = None,
password: str = None,
database: str = None,
charset: str = "utf8mb4",
use_unicode: bool = False,
buffered: bool = False,
auth_plugin: str = "caching_sha2_password",
ssl_disabled: bool = False,
host: str = Environment.get("DB_HOST", str),
port: int = Environment.get("DB_PORT", str, Configuration.get("DB_DEFAULT_PORT", 0)),
user: str = Environment.get("DB_USER", str),
password: str = Environment.get("DB_PASSWORD", str),
database: str = Environment.get("DB_DATABASE", str),
charset: str = Environment.get("DB_CHARSET", str, "utf8mb4"),
use_unicode: bool = Environment.get("DB_USE_UNICODE", bool, False),
buffered: bool = Environment.get("DB_BUFFERED", bool, False),
auth_plugin: str = Environment.get("DB_AUTH_PLUGIN", str, "caching_sha2_password"),
ssl_disabled: bool = Environment.get("DB_SSL_DISABLED", bool, False),
):
ConfigurationModelABC.__init__(self)
self._host: Optional[str] = host
self._port: Optional[int] = port
self._user: Optional[str] = user
self._password: Optional[str] = password
self._password: Optional[str] = Base64.decode(password) if Base64.is_b64(password) else password
self._database: Optional[str] = database
self._charset: Optional[str] = charset
self._use_unicode: Optional[bool] = use_unicode

View File

@@ -0,0 +1,12 @@
class Migration:
def __init__(self, name: str, script: str):
self._name = name
self._script = script
@property
def name(self) -> str:
return self._name
@property
def script(self) -> str:
return self._script

View File

@@ -0,0 +1,21 @@
from enum import Enum
class ServerTypes(Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
class ServerType:
_server_type: ServerTypes = None
@classmethod
def set_server_type(cls, server_type: ServerTypes):
assert server_type is not None, "server_type must not be None"
assert isinstance(server_type, ServerTypes), f"Expected ServerType but got {type(server_type)}"
cls._server_type = server_type
@classmethod
@property
def server_type(cls) -> ServerTypes:
assert cls._server_type is not None, "Server type is not set"
return cls._server_type

View File

@@ -4,16 +4,16 @@ import mysql.connector as sql
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.connector.cursor import MySQLCursorBuffered
from cpl.database.mysql.connection.database_connection_abc import DatabaseConnectionABC
from cpl.database.abc.connection_abc import ConnectionABC
from cpl.database.database_settings import DatabaseSettings
from cpl.core.utils.credential_manager import CredentialManager
class DatabaseConnection(DatabaseConnectionABC):
class DatabaseConnection(ConnectionABC):
r"""Representation of the database connection"""
def __init__(self):
DatabaseConnectionABC.__init__(self)
ConnectionABC.__init__(self)
self._database: Optional[MySQLConnectionAbstract] = None
self._cursor: Optional[MySQLCursorBuffered] = None

View File

@@ -1,2 +0,0 @@
from .database_connection import DatabaseConnection
from .database_connection_abc import DatabaseConnectionABC

View File

@@ -1,2 +0,0 @@
from .database_context import DatabaseContext
from .database_context_abc import DatabaseContextABC

View File

@@ -1,52 +0,0 @@
from typing import Optional
from cpl.database.mysql.connection.database_connection import DatabaseConnection
from cpl.database.mysql.connection.database_connection_abc import DatabaseConnectionABC
from cpl.database.mysql.context.database_context_abc import DatabaseContextABC
from cpl.database.database_settings import DatabaseSettings
from mysql.connector.cursor import MySQLCursorBuffered
class DatabaseContext(DatabaseContextABC):
r"""Representation of the database context
Parameter:
database_settings: :class:`cpl.database.database_settings.DatabaseSettings`
"""
def __init__(self):
DatabaseContextABC.__init__(self)
self._db: DatabaseConnectionABC = DatabaseConnection()
self._settings: Optional[DatabaseSettings] = None
@property
def cursor(self) -> MySQLCursorBuffered:
self._ping_and_reconnect()
return self._db.cursor
def _ping_and_reconnect(self):
try:
self._db.server.ping(reconnect=True, attempts=3, delay=5)
except Exception:
# reconnect your cursor as you did in __init__ or wherever
if self._settings is None:
raise Exception("Call DatabaseContext.connect first")
self.connect(self._settings)
def connect(self, database_settings: DatabaseSettings):
if self._settings is None:
self._settings = database_settings
self._db.connect(database_settings)
self.save_changes()
def save_changes(self):
self._ping_and_reconnect()
self._db.server.commit()
def select(self, statement: str) -> list[tuple]:
self._ping_and_reconnect()
self._db.cursor.execute(statement)
return self._db.cursor.fetchall()

View File

@@ -1,40 +0,0 @@
from abc import ABC, abstractmethod
from cpl.database.database_settings import DatabaseSettings
from mysql.connector.cursor import MySQLCursorBuffered
class DatabaseContextABC(ABC):
r"""ABC for the :class:`cpl.database.context.database_context.DatabaseContext`"""
@abstractmethod
def __init__(self, *args):
pass
@property
@abstractmethod
def cursor(self) -> MySQLCursorBuffered:
pass
@abstractmethod
def connect(self, database_settings: DatabaseSettings):
r"""Connects to a database by connection settings
Parameter:
database_settings :class:`cpl.database.database_settings.DatabaseSettings`
"""
@abstractmethod
def save_changes(self):
r"""Saves changes of the database"""
@abstractmethod
def select(self, statement: str) -> list[tuple]:
r"""Runs SQL Statements
Parameter:
statement: :class:`str`
Returns:
list: Fetched list of selected elements
"""

View File

@@ -0,0 +1,84 @@
import uuid
from typing import Any, List, Dict, Tuple, Union
from mysql.connector import Error as MySQLError, PoolError
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.db_logger import DBLogger
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_pool import MySQLPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC):
def __init__(self):
DBContextABC.__init__(self)
self._pool: MySQLPool = None
self._fails = 0
self.connect(Configuration.get(DatabaseSettings))
def connect(self, database_settings: DatabaseSettings):
try:
_logger.debug("Connecting to database")
self._pool = MySQLPool(
database_settings,
)
_logger.info("Connected to database")
except Exception as e:
_logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
_logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> List[Dict]:
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select_map(statement, args)
except (MySQLError, PoolError) as e:
if self._fails >= 3:
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
_logger.debug("Retry select")
return await self.select_map(statement, args)
except Exception as e:
pass
return []
except Exception as e:
_logger.error(f"Database error caused by `{statement}`", e)
raise e
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select(statement, args)
except (MySQLError, PoolError) as e:
if self._fails >= 3:
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
_logger.debug("Retry select")
return await self.select(statement, args)
except Exception as e:
pass
return []
except Exception as e:
_logger.error(f"Database error caused by `{statement}`", e)
raise e

View File

@@ -0,0 +1,105 @@
from typing import Optional, Any
import sqlparse
import aiomysql
from cpl.core.environment import Environment
from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings
_logger = DBLogger(__name__)
class MySQLPool:
"""
Create a pool when connecting to MySQL, which will decrease the time spent in
requesting connection, creating connection, and closing connection.
"""
def __init__(self, database_settings: DatabaseSettings):
self._db_settings = database_settings
self.pool: Optional[aiomysql.Pool] = None
async def _get_pool(self):
if self.pool is None or self.pool._closed:
try:
self.pool = await aiomysql.create_pool(
host=self._db_settings.host,
port=self._db_settings.port,
user=self._db_settings.user,
password=self._db_settings.password,
db=self._db_settings.database,
minsize=1,
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
autocommit=True,
)
except Exception as e:
_logger.fatal("Failed to connect to the database", e)
raise
return self.pool
@staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries:
if q.strip() == "":
continue
await cursor.execute(q, args)
else:
await cursor.execute(query, args)
async def execute(self, query: str, args=None, multi=True) -> list[list]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
if cursor.description is not None: # Query returns rows
res = await cursor.fetchall()
if res is None:
return []
return [list(row) for row in res]
else:
return []
async def select(self, query: str, args=None, multi=True) -> list[str]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
return list(res)
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor(aiomysql.DictCursor) as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
return list(res)

View File

@@ -0,0 +1,86 @@
import uuid
from typing import Any
from psycopg import OperationalError
from psycopg_pool import PoolTimeout
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.database_settings import DatabaseSettings
from cpl.database.db_logger import DBLogger
from cpl.database.postgres.postgres_pool import PostgresPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC):
def __init__(self):
DBContextABC.__init__(self)
self._pool: PostgresPool = None
self._fails = 0
self.connect(Configuration.get(DatabaseSettings))
def connect(self, database_settings: DatabaseSettings):
try:
_logger.debug("Connecting to database")
self._pool = PostgresPool(
database_settings,
Environment.get("DB_POOL_SIZE", int, 1),
)
_logger.info("Connected to database")
except Exception as e:
_logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
_logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> list[dict]:
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select_map(statement, args)
except (OperationalError, PoolTimeout) as e:
if self._fails >= 3:
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
_logger.debug("Retry select")
return await self.select_map(statement, args)
except Exception as e:
pass
return []
except Exception as e:
_logger.error(f"Database error caused by `{statement}`", e)
raise e
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select(statement, args)
except (OperationalError, PoolTimeout) as e:
if self._fails >= 3:
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
_logger.debug("Retry select")
return await self.select(statement, args)
except Exception as e:
pass
return []
except Exception as e:
_logger.error(f"Database error caused by `{statement}`", e)
raise e

View File

@@ -0,0 +1,123 @@
from typing import Optional, Any
import sqlparse
from psycopg import sql
from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment
from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings
_logger = DBLogger(__name__)
class PostgresPool:
"""
Create a pool when connecting to PostgreSQL, which will decrease the time spent in
requesting connection, creating connection, and closing connection.
"""
def __init__(self, database_settings: DatabaseSettings):
self._conninfo = (
f"host={database_settings.host} "
f"port={database_settings.port} "
f"user={database_settings.user} "
f"password={database_settings.password} "
f"dbname={database_settings.database}"
)
self.pool: Optional[AsyncConnectionPool] = None
async def _get_pool(self):
pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
)
await pool.open()
try:
async with pool.connection() as con:
await pool.check_connection(con)
except PoolTimeout as e:
await pool.close()
_logger.fatal(f"Failed to connect to the database", e)
return pool
@staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries:
if q.strip() == "":
continue
await cursor.execute(sql.SQL(q), args)
else:
await cursor.execute(sql.SQL(query), args)
async def execute(self, query: str, args=None, multi=True) -> list[list]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in the psycopg module.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
async with await self._get_pool() as pool:
async with pool.connection() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
if cursor.description is not None: # Check if the query returns rows
res = await cursor.fetchall()
if res is None:
return []
result = []
for row in res:
result.append(list(row))
return result
else:
return []
async def select(self, query: str, args=None, multi=True) -> list[str]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in the psycopg module.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
async with await self._get_pool() as pool:
async with pool.connection() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
return list(res)
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in the psycopg module.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
async with await self._get_pool() as pool:
async with pool.connection() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
res_map: list[dict] = []
for i_res in range(len(res)):
cols = {}
for i_col in range(len(res[i_res])):
cols[cursor.description[i_col].name] = res[i_res][i_col]
res_map.append(cols)
return res_map

View File

@@ -0,0 +1,154 @@
from typing import Optional, Union
from cpl.database._external_data_temp_table_builder import ExternalDataTempTableBuilder
class SQLSelectBuilder:
def __init__(self, table_name: str, primary_key: str):
self._table_name = table_name
self._primary_key = primary_key
self._temp_tables: dict[str, ExternalDataTempTableBuilder] = {}
self._to_use_temp_tables: list[str] = []
self._attributes: list[str] = []
self._tables: list[str] = [table_name]
self._joins: dict[str, (str, str)] = {}
self._conditions: list[str] = []
self._order_by: str = ""
self._limit: Optional[int] = None
self._offset: Optional[int] = None
def with_temp_table(self, temp_table: ExternalDataTempTableBuilder) -> "SQLSelectBuilder":
self._temp_tables[temp_table.table_name] = temp_table
return self
def use_temp_table(self, temp_table_name: str):
if temp_table_name not in self._temp_tables:
raise ValueError(f"Temp table {temp_table_name} not found.")
self._to_use_temp_tables.append(temp_table_name)
def with_attribute(self, attr: str, ignore_table_name=False) -> "SQLSelectBuilder":
if not ignore_table_name and not attr.startswith(self._table_name):
attr = f"{self._table_name}.{attr}"
self._attributes.append(attr)
return self
def with_foreign_attribute(self, attr: str) -> "SQLSelectBuilder":
self._attributes.append(attr)
return self
def with_table(self, table_name: str) -> "SQLSelectBuilder":
self._tables.append(table_name)
return self
def _check_prefix(self, attr: str, foreign_tables: list[str]) -> str:
assert attr is not None
if "TO_CHAR" in attr:
return attr
valid_prefixes = [
"levenshtein",
self._table_name,
*self._joins.keys(),
*self._temp_tables.keys(),
*foreign_tables,
]
if not any(attr.startswith(f"{prefix}.") for prefix in valid_prefixes):
attr = f"{self._table_name}.{attr}"
return attr
def with_value_condition(
self, attr: str, operator: str, value: str, foreign_tables: list[str]
) -> "SQLSelectBuilder":
attr = self._check_prefix(attr, foreign_tables)
self._conditions.append(f"{attr} {operator} {value}")
return self
def with_levenshtein_condition(self, condition: str) -> "SQLSelectBuilder":
self._conditions.append(condition)
return self
def with_condition(self, attr: str, operator: str, foreign_tables: list[str]) -> "SQLSelectBuilder":
attr = self._check_prefix(attr, foreign_tables)
self._conditions.append(f"{attr} {operator}")
return self
def with_grouped_conditions(self, conditions: list[str]) -> "SQLSelectBuilder":
self._conditions.append(f"({' AND '.join(conditions)})")
return self
def with_left_join(self, table: str, on: str) -> "SQLSelectBuilder":
if table in self._joins:
self._joins[table] = (f"{self._joins[table][0]} AND {on}", "LEFT")
self._joins[table] = (on, "LEFT")
return self
def with_inner_join(self, table: str, on: str) -> "SQLSelectBuilder":
if table in self._joins:
self._joins[table] = (f"{self._joins[table][0]} AND {on}", "INNER")
self._joins[table] = (on, "INNER")
return self
def with_right_join(self, table: str, on: str) -> "SQLSelectBuilder":
if table in self._joins:
self._joins[table] = (f"{self._joins[table][0]} AND {on}", "RIGHT")
self._joins[table] = (on, "RIGHT")
return self
def with_limit(self, limit: int) -> "SQLSelectBuilder":
self._limit = limit
return self
def with_offset(self, offset: int) -> "SQLSelectBuilder":
self._offset = offset
return self
def with_order_by(self, column: Union[str, property], direction: str = "ASC") -> "SQLSelectBuilder":
if isinstance(column, property):
column = column.fget.__name__
self._order_by = f"{column} {direction}"
return self
async def _handle_temp_table_use(self, query) -> str:
new_query = ""
for temp_table_name in self._to_use_temp_tables:
temp_table = self._temp_tables[temp_table_name]
new_query += await self._temp_tables[temp_table_name].build()
self.with_left_join(
temp_table.table_name,
f"{temp_table.join_ref_table}.{self._primary_key} = {temp_table.table_name}.{temp_table.primary_key}",
)
return f"{new_query} {query}" if new_query != "" else query
async def build(self) -> str:
query = await self._handle_temp_table_use("")
attributes = ", ".join(self._attributes) if self._attributes else "*"
query += f"SELECT {attributes} FROM {", ".join(self._tables)}"
for join in self._joins:
query += f" {self._joins[join][1]} JOIN {join} ON {self._joins[join][0]}"
if self._conditions:
query += " WHERE " + " AND ".join(self._conditions)
if self._order_by:
query += f" ORDER BY {self._order_by}"
if self._limit is not None:
query += f" LIMIT {self._limit}"
if self._offset is not None:
query += f" OFFSET {self._offset}"
return query

View File

@@ -0,0 +1,18 @@
from datetime import datetime
from typing import Optional
from cpl.database.abc import DbModelABC
class ExecutedMigration(DbModelABC):
def __init__(
self,
migration_id: str,
created: Optional[datetime] = None,
modified: Optional[datetime] = None,
):
DbModelABC.__init__(self, migration_id, False, created, modified)
@property
def migration_id(self) -> str:
return self._id

View File

@@ -0,0 +1,14 @@
from cpl.database import InternalTables
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.db_logger import DBLogger
from cpl.database.schema.executed_migration import ExecutedMigration
_logger = DBLogger(__name__)
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
def __init__(self):
DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, InternalTables.executed_migrations)
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")

View File

@@ -0,0 +1,6 @@
CREATE TABLE IF NOT EXISTS _executed_migrations
(
migrationId VARCHAR(255) PRIMARY KEY,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);

View File

@@ -0,0 +1,26 @@
DELIMITER //
CREATE TRIGGER mytable_before_update
BEFORE UPDATE
ON mytable
FOR EACH ROW
BEGIN
INSERT INTO mytable_history
SELECT OLD.*;
SET NEW.updated = NOW();
END;
//
DELIMITER ;
DELIMITER //
CREATE TRIGGER mytable_before_delete
BEFORE DELETE
ON mytable
FOR EACH ROW
BEGIN
INSERT INTO mytable_history
SELECT OLD.*;
END;
//
DELIMITER ;

View File

@@ -0,0 +1,47 @@
CREATE SCHEMA IF NOT EXISTS public;
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system._executed_migrations
(
MigrationId VARCHAR(255) PRIMARY KEY,
Created timestamptz NOT NULL DEFAULT NOW(),
Updated timestamptz NOT NULL DEFAULT NOW()
);
CREATE OR REPLACE FUNCTION public.history_trigger_function()
RETURNS TRIGGER AS
$$
DECLARE
schema_name TEXT;
history_table_name TEXT;
BEGIN
-- Construct the name of the history table based on the current table
schema_name := TG_TABLE_SCHEMA;
history_table_name := TG_TABLE_NAME || '_history';
IF (TG_OP = 'INSERT') THEN
RETURN NEW;
END IF;
-- Insert the old row into the history table on UPDATE or DELETE
IF (TG_OP = 'UPDATE' OR TG_OP = 'DELETE') THEN
EXECUTE format(
'INSERT INTO %I.%I SELECT ($1).*',
schema_name,
history_table_name
)
USING OLD;
END IF;
-- For UPDATE, update the Updated column and return the new row
IF (TG_OP = 'UPDATE') THEN
NEW.updated := NOW(); -- Update the Updated column
RETURN NEW;
END IF;
-- For DELETE, return OLD to allow the deletion
IF (TG_OP = 'DELETE') THEN
RETURN OLD;
END IF;
END;
$$ LANGUAGE plpgsql;

View File

@@ -0,0 +1,111 @@
import glob
import os
from cpl.database.abc import DBContextABC
from cpl.database.db_logger import DBLogger
from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
_logger = DBLogger(__name__)
class MigrationService:
def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
self._db = db
self._executedMigrationDao = executedMigrationDao
self._script_directories: list[str] = []
if ServerType.server_type == ServerTypes.POSTGRES:
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/postgres"))
elif ServerType.server_type == ServerTypes.MYSQL:
self.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../scripts/mysql"))
def with_directory(self, directory: str) -> "MigrationService":
self._script_directories.append(directory)
return self
async def _get_migration_history(self) -> list[ExecutedMigration]:
results = await self._db.select(f"SELECT * FROM {self._executedMigrationDao.table_name}")
applied_migrations = []
for result in results:
applied_migrations.append(ExecutedMigration(result[0]))
return applied_migrations
@staticmethod
def _load_scripts_by_path(path: str) -> list[Migration]:
migrations = []
if not os.path.exists(path):
raise Exception("Migration path not found")
files = sorted(glob.glob(f"{path}/*"))
for file in files:
if not file.endswith(".sql"):
continue
name = str(file.split(".sql")[0])
if "/" in name:
name = name.split("/")[-1]
with open(f"{file}", "r") as f:
script = f.read()
f.close()
migrations.append(Migration(name, script))
return migrations
def _load_scripts(self) -> list[Migration]:
migrations = []
for path in self._script_directories:
migrations.extend(self._load_scripts_by_path(path))
return migrations
async def _get_tables(self):
if ServerType == ServerTypes.POSTGRES:
return await self._db.select(
"""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public';
"""
)
else:
return await self._db.select(
"""
SHOW TABLES;
"""
)
async def _execute(self, migrations: list[Migration]):
result = await self._get_tables()
for migration in migrations:
active_statement = ""
try:
# check if table exists
if len(result) > 0:
migration_from_db = await self._executedMigrationDao.find_by_id(migration.name)
if migration_from_db is not None:
continue
_logger.debug(f"Running upgrade migration: {migration.name}")
await self._db.execute(migration.script, multi=True)
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e:
_logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}",
e,
)
async def migrate(self):
await self._execute(self._load_scripts())

View File

@@ -0,0 +1,65 @@
from datetime import datetime
from typing import TypeVar, Union, Literal, Any
from cpl.database.abc.db_model_abc import DbModelABC
T_DBM = TypeVar("T_DBM", bound=DbModelABC)
NumberFilterOperator = Literal[
"equal",
"notEqual",
"greater",
"greaterOrEqual",
"less",
"lessOrEqual",
"isNull",
"isNotNull",
]
StringFilterOperator = Literal[
"equal",
"notEqual",
"contains",
"notContains",
"startsWith",
"endsWith",
"isNull",
"isNotNull",
]
BoolFilterOperator = Literal[
"equal",
"notEqual",
"isNull",
"isNotNull",
]
DateFilterOperator = Literal[
"equal",
"notEqual",
"greater",
"greaterOrEqual",
"less",
"lessOrEqual",
"isNull",
"isNotNull",
]
FilterOperator = Union[NumberFilterOperator, StringFilterOperator, BoolFilterOperator, DateFilterOperator]
Attribute = Union[str, property]
AttributeCondition = Union[
dict[NumberFilterOperator, int],
dict[StringFilterOperator, str],
dict[BoolFilterOperator, bool],
dict[DateFilterOperator, datetime],
]
AttributeFilter = dict[Attribute, Union[list[Union[AttributeCondition, Any]], AttributeCondition, Any]]
AttributeFilters = Union[
list[AttributeFilter],
AttributeFilter,
]
AttributeSort = dict[Attribute, Literal["asc", "desc"]]
AttributeSorts = Union[
list[AttributeSort],
AttributeSort,
]

View File

@@ -1,2 +1,8 @@
cpl-core
cpl-dependency
cpl-dependency
psycopg[binary]==3.2.3
psycopg-pool==3.2.4
sqlparse==0.5.3
mysql-connector-python==9.4.0
async-property==0.2.2
aiomysql==0.2.0