diff --git a/api/src/api/route_user_extension.py b/api/src/api/route_user_extension.py index 585390d..9efee96 100644 --- a/api/src/api/route_user_extension.py +++ b/api/src/api/route_user_extension.py @@ -53,7 +53,7 @@ class RouteUserExtension: user_id = cls._get_user_id_from_token(request) if not user_id: - return None + return await cls.get_dev_user() return await userDao.find_by_keycloak_id(user_id) diff --git a/api/src/api_graphql/field/mutation_field_builder.py b/api/src/api_graphql/field/mutation_field_builder.py index b6b7576..245045f 100644 --- a/api/src/api_graphql/field/mutation_field_builder.py +++ b/api/src/api_graphql/field/mutation_field_builder.py @@ -1,5 +1,5 @@ from asyncio import iscoroutinefunction -from typing import Self, Type +from typing import Self, Type, Union from ariadne.types import Resolver @@ -58,7 +58,11 @@ class MutationFieldBuilder(FieldBuilderABC): self._resolver = resolver_wrapper return self - def with_input(self, input_type: Type[InputABC], input_key: str = "input") -> Self: + def with_input( + self, + input_type: Type[Union[InputABC, str, int, bool]], + input_key: str = "input", + ) -> Self: self._input_type = input_type self._input_key = input_key return self diff --git a/api/src/api_graphql/graphql/mutation.gql b/api/src/api_graphql/graphql/mutation.gql index ad67439..6dff6f7 100644 --- a/api/src/api_graphql/graphql/mutation.gql +++ b/api/src/api_graphql/graphql/mutation.gql @@ -11,4 +11,6 @@ type Mutation { setting: SettingMutation userSetting: UserSettingMutation featureFlag: FeatureFlagMutation + + privacy: PrivacyMutation } \ No newline at end of file diff --git a/api/src/api_graphql/graphql/privacy.gql b/api/src/api_graphql/graphql/privacy.gql new file mode 100644 index 0000000..0a1410c --- /dev/null +++ b/api/src/api_graphql/graphql/privacy.gql @@ -0,0 +1,5 @@ +type PrivacyMutation { + exportData(userId: Int!): String + anonymizeData(userId: Int!): String + deleteData(userId: Int!): String +} \ No newline at end of file diff --git a/api/src/api_graphql/mutation.py b/api/src/api_graphql/mutation.py index 2514018..f996321 100644 --- a/api/src/api_graphql/mutation.py +++ b/api/src/api_graphql/mutation.py @@ -86,3 +86,7 @@ class Mutation(MutationABC): Permissions.administrator, ], ) + self.add_mutation_type( + "privacy", + "Privacy", + ) diff --git a/api/src/api_graphql/mutations/privacy_mutation.py b/api/src/api_graphql/mutations/privacy_mutation.py new file mode 100644 index 0000000..b0ab8a6 --- /dev/null +++ b/api/src/api_graphql/mutations/privacy_mutation.py @@ -0,0 +1,63 @@ +from api.route import Route +from api_graphql.abc.mutation_abc import MutationABC +from api_graphql.field.mutation_field_builder import MutationFieldBuilder +from api_graphql.service.exceptions import UnauthorizedException, AccessDenied +from core.logger import APILogger +from service.data_privacy_service import DataPrivacyService +from service.permission.permissions_enum import Permissions + +logger = APILogger(__name__) + + +class PrivacyMutation(MutationABC): + def __init__(self): + MutationABC.__init__(self, "Privacy") + + self.field( + MutationFieldBuilder("exportData") + .with_resolver(self.resolve_export_data) + .with_input(int, "userId") + .with_public(True) + ) + self.field( + MutationFieldBuilder("anonymizeData") + .with_resolver(self.resolve_anonymize_data) + .with_input(int, "userId") + .with_public(True) + ) + self.field( + MutationFieldBuilder("deleteData") + .with_resolver(self.resolve_delete_data) + .with_input(int, "userId") + .with_public(True) + ) + + @staticmethod + async def _permission_check(user_id: int): + user = await Route.get_user() + if user is None: + raise UnauthorizedException() + + if user.id != user_id and not user.has_permission(Permissions.administrator): + raise AccessDenied() + + @staticmethod + async def resolve_export_data(user_id: int, *_): + logger.debug(f"export data for user: {user_id}") + await PrivacyMutation._permission_check(user_id) + + return await DataPrivacyService.export_user_data(user_id) + + @staticmethod + async def resolve_anonymize_data(user_id: int, *_): + logger.debug(f"anonymize data for user: {user_id}") + await PrivacyMutation._permission_check(user_id) + + return await DataPrivacyService.anonymize_user(user_id) + + @staticmethod + async def resolve_delete_data(user_id: int, *_): + logger.debug(f"delete data for user: {user_id}") + await PrivacyMutation._permission_check(user_id) + + return await DataPrivacyService.delete_user_data(user_id) diff --git a/api/src/api_graphql/mutations/setting_mutation.py b/api/src/api_graphql/mutations/setting_mutation.py index 4c499f3..7aa2404 100644 --- a/api/src/api_graphql/mutations/setting_mutation.py +++ b/api/src/api_graphql/mutations/setting_mutation.py @@ -2,7 +2,7 @@ from api_graphql.abc.mutation_abc import MutationABC from api_graphql.input.setting_input import SettingInput from core.logger import APILogger from data.schemas.system.setting import Setting -from data.schemas.system.setting_dao import settingsDao +from data.schemas.system.setting_dao import settingDao from service.permission.permissions_enum import Permissions logger = APILogger(__name__) @@ -22,11 +22,11 @@ class SettingMutation(MutationABC): async def resolve_change(obj: SettingInput, *_): logger.debug(f"create new setting: {input}") - setting = await settingsDao.find_single_by({Setting.key: obj.key}) + setting = await settingDao.find_single_by({Setting.key: obj.key}) if setting is None: raise ValueError(f"Setting with key {obj.key} not found") setting.value = obj.value - await settingsDao.update(setting) + await settingDao.update(setting) - return await settingsDao.get_by_id(setting.id) + return await settingDao.get_by_id(setting.id) diff --git a/api/src/api_graphql/mutations/user_setting_mutation.py b/api/src/api_graphql/mutations/user_setting_mutation.py index bca1a9b..ff1be20 100644 --- a/api/src/api_graphql/mutations/user_setting_mutation.py +++ b/api/src/api_graphql/mutations/user_setting_mutation.py @@ -5,8 +5,8 @@ from api_graphql.input.user_setting_input import UserSettingInput from core.logger import APILogger from core.string import first_to_lower from data.schemas.public.user_setting import UserSetting -from data.schemas.public.user_setting_dao import userSettingsDao -from data.schemas.system.setting_dao import settingsDao +from data.schemas.public.user_setting_dao import userSettingDao +from data.schemas.system.setting_dao import settingDao from service.permission.permissions_enum import Permissions logger = APILogger(__name__) @@ -37,13 +37,13 @@ class UserSettingMutation(MutationABC): logger.debug("user not authorized") return None - setting = await userSettingsDao.find_single_by( + setting = await userSettingDao.find_single_by( [{UserSetting.user_id: user.id}, {UserSetting.key: obj.key}] ) if setting is None: - await userSettingsDao.create(UserSetting(0, user.id, obj.key, obj.value)) + await userSettingDao.create(UserSetting(0, user.id, obj.key, obj.value)) else: setting.value = obj.value - await userSettingsDao.update(setting) + await userSettingDao.update(setting) - return await userSettingsDao.find_by_key(user, obj.key) + return await userSettingDao.find_by_key(user, obj.key) diff --git a/api/src/api_graphql/query.py b/api/src/api_graphql/query.py index 0f949f5..2e80df3 100644 --- a/api/src/api_graphql/query.py +++ b/api/src/api_graphql/query.py @@ -32,9 +32,9 @@ from data.schemas.public.group_dao import groupDao from data.schemas.public.short_url import ShortUrl from data.schemas.public.short_url_dao import shortUrlDao from data.schemas.public.user_setting import UserSetting -from data.schemas.public.user_setting_dao import userSettingsDao +from data.schemas.public.user_setting_dao import userSettingDao from data.schemas.system.feature_flag_dao import featureFlagDao -from data.schemas.system.setting_dao import settingsDao +from data.schemas.system.setting_dao import settingDao from service.permission.permissions_enum import Permissions @@ -206,8 +206,8 @@ class Query(QueryABC): @staticmethod async def _resolve_settings(*args, **kwargs): if "key" in kwargs: - return [await settingsDao.find_by_key(kwargs["key"])] - return await settingsDao.get_all() + return [await settingDao.find_by_key(kwargs["key"])] + return await settingDao.get_all() @staticmethod async def _resolve_user_settings(*args, **kwargs): @@ -216,10 +216,10 @@ class Query(QueryABC): return None if "key" in kwargs: - return await userSettingsDao.find_by( + return await userSettingDao.find_by( [{UserSetting.user_id: user.id}, {UserSetting.key: kwargs["key"]}] ) - return await userSettingsDao.find_by({UserSetting.user_id: user.id}) + return await userSettingDao.find_by({UserSetting.user_id: user.id}) @staticmethod async def _resolve_feature_flags(*args, **kwargs): diff --git a/api/src/core/database/abc/data_access_object_abc.py b/api/src/core/database/abc/data_access_object_abc.py index 230b77b..4fe4874 100644 --- a/api/src/core/database/abc/data_access_object_abc.py +++ b/api/src/core/database/abc/data_access_object_abc.py @@ -2,38 +2,23 @@ import datetime from abc import ABC, abstractmethod from enum import Enum from types import NoneType -from typing import Generic, Optional, Union, TypeVar, Any, Type +from typing import Generic, Optional, Union, Type, List, Any, TypeVar from core.const import DATETIME_FORMAT from core.database.abc.db_model_abc import DbModelABC from core.database.database import Database from core.database.external_data_temp_table_builder import ExternalDataTempTableBuilder +from core.database.sql_select_builder import SQLSelectBuilder from core.get_value import get_value from core.logger import DBLogger from core.string import camel_to_snake -from core.typing import T, Attribute, AttributeFilters, AttributeSorts, Id +from core.typing import AttributeFilters, AttributeSorts, Id, Attribute, T T_DBM = TypeVar("T_DBM", bound=DbModelABC) class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): _external_fields: dict[str, ExternalDataTempTableBuilder] = {} - _operators = [ - "equal", - "notEqual", - "greater", - "greaterOrEqual", - "less", - "lessOrEqual", - "isNull", - "isNotNull", - "contains", - "notContains", - "startsWith", - "endsWith", - "in", - "notIn", - ] @abstractmethod def __init__(self, source: str, model_type: Type[T_DBM], table_name: str): @@ -41,13 +26,16 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): self._model_type = model_type self._table_name = table_name + self._logger = DBLogger(source) + self._model_type = model_type + self._table_name = table_name + self._default_filter_condition = None self.__attributes: dict[str, type] = {} - self.__joins: list[str] = [] self.__db_names: dict[str, str] = {} - self.__foreign_tables: dict[str, str] = {} + self.__foreign_tables: dict[str, tuple[str, str]] = {} self.__foreign_table_keys: dict[str, str] = {} self.__date_attributes: set[str] = set() @@ -60,6 +48,14 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): def table_name(self) -> str: return self._table_name + def has_attribute(self, attr_name: Attribute) -> bool: + """ + Check if the attribute exists in the DAO + :param Attribute attr_name: Name of the attribute + :return: True if the attribute exists, False otherwise + """ + return attr_name in self.__attributes + def attribute( self, attr_name: Attribute, @@ -138,24 +134,10 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): if table_name == self._table_name: return - if any(f"{table_name} ON" in join for join in self.__joins): - index = next( - ( - i - for i, join in enumerate(self.__joins) - if f"{table_name} ON" in join - ), - None, - ) - if index is not None: - self.__joins[ - index - ] += f" AND {table_name}.{primary_attr} = {self._table_name}.{foreign_attr}" - else: - self.__joins.append( - f"LEFT JOIN {table_name} ON {table_name}.{primary_attr} = {self._table_name}.{foreign_attr}" - ) - self.__foreign_tables[attr] = table_name + self.__foreign_tables[attr] = ( + table_name, + f"{table_name}.{primary_attr} = {self._table_name}.{foreign_attr}", + ) def use_external_fields(self, builder: ExternalDataTempTableBuilder): self._external_fields[builder.table_name] = builder @@ -180,22 +162,28 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return self._model_type(**value_map) + def to_dict(self, obj: T_DBM) -> dict: + """ + Convert an object to a dictionary + :param T_DBM obj: Object to convert + :return: + """ + value_map: dict[str, Any] = {} + + for attr_name, attr_type in self.__attributes.items(): + value = getattr(obj, attr_name) + if isinstance(value, datetime.datetime): + value = value.strftime(DATETIME_FORMAT) + elif isinstance(value, Enum): + value = value.value + + value_map[attr_name] = value + + return value_map + async def count(self, filters: AttributeFilters = None) -> int: - query = f"SELECT COUNT(*) FROM {self._table_name}" - join_str = f" ".join(self.__joins) - query += f" {join_str}" if len(join_str) > 0 else "" - - if filters is not None and (not isinstance(filters, list) or len(filters) > 0): - conditions, external_table_deps = await self._build_conditions(filters) - query = await self._handle_query_external_temp_tables( - query, external_table_deps, ignore_fields=True - ) - query += f" WHERE {conditions};" - - result = await self._db.select_map(query) - if len(result) == 0: - return 0 - return result[0]["count"] + result = await self._prepare_query(filters=filters, for_count=True) + return result[0]["count"] if result else 0 async def get_history( self, @@ -203,57 +191,66 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): by_key: str = None, when: datetime = None, until: datetime = None, - without_deleted=False, + without_deleted: bool = False, ) -> list[T_DBM]: - query = f"SELECT {self._table_name}_history.* FROM {self._table_name}_history" + """ + Retrieve the history of an entry from the history table. + :param entry_id: The ID of the entry to retrieve history for. + :param by_key: The key to filter by (default is the primary key). + :param when: A specific timestamp to filter the history. + :param until: A timestamp to filter history entries up to a certain point. + :param without_deleted: Exclude deleted entries if True. + :return: A list of historical entries as objects. + """ + history_table = f"{self._table_name}_history" + builder = SQLSelectBuilder(history_table, self.__primary_key) - for join in self.__joins: - query += f" {join.replace(self._table_name, f'{self._table_name}_history')}" + builder.with_attribute("*") + builder.with_value_condition( + f"{history_table}.{by_key or self.__primary_key}", "=", str(entry_id) + ) - query += f" WHERE {f'{self._table_name}_history.{self.__primary_key}' if by_key is None else f'{self._table_name}_history.{by_key}'} = {entry_id}" - - if self._default_filter_condition is not None: - query += f" AND {self._default_filter_condition}" + if self._default_filter_condition: + builder.with_condition(self._default_filter_condition, "") if without_deleted: - query += f" AND {self._table_name}_history.deleted = false" + builder.with_value_condition(f"{history_table}.deleted", "=", "false") - if when is not None: - query += f" AND {self._attr_from_date_to_char(f'{self._table_name}_history.updated')} = '{when.strftime(DATETIME_FORMAT)}'" + if when: + builder.with_value_condition( + self._attr_from_date_to_char(f"{history_table}.updated"), + "=", + f"'{when.strftime(DATETIME_FORMAT)}'", + ) - if until is not None: - query += f" AND {self._attr_from_date_to_char(f'{self._table_name}_history.updated')} <= '{until.strftime(DATETIME_FORMAT)}'" + if until: + builder.with_value_condition( + self._attr_from_date_to_char(f"{history_table}.updated"), + "<=", + f"'{until.strftime(DATETIME_FORMAT)}'", + ) - query += f" ORDER BY {self._table_name}_history.updated DESC;" + builder.with_order_by(f"{history_table}.updated", "DESC") + query = await builder.build() result = await self._db.select_map(query) - if result is None: - return [] - return [self.to_object(x) for x in result] + return [self.to_object(x) for x in result] if result else [] - async def get_all(self) -> list[T_DBM]: - result = await self._db.select_map( - f"SELECT * FROM {self._table_name}{f" WHERE {self._default_filter_condition}" if self._default_filter_condition is not None else ''} ORDER BY {self.__primary_key};" + async def get_all(self) -> List[T_DBM]: + result = await self._prepare_query(sorts=[{self.__primary_key: "asc"}]) + return [self.to_object(x) for x in result] if result else [] + + async def get_by_id(self, id: Union[int, str]) -> Optional[T_DBM]: + result = await self._prepare_query( + filters=[{self.__primary_key: id}], sorts=[{self.__primary_key: "asc"}] ) - if result is None: - return [] - - return [self.to_object(x) for x in result] - - async def get_by_id(self, id: Union[int, str]) -> T_DBM: - result = await self._db.select_map( - f"SELECT * FROM {self._table_name} WHERE {f"{self._default_filter_condition} AND " if self._default_filter_condition is not None else ''} {self.__primary_key} = {f"'{id}'" if isinstance(id, str) else id}" - ) - return self.to_object(result[0]) + return self.to_object(result[0]) if result else None async def find_by_id(self, id: Union[int, str]) -> Optional[T_DBM]: - result = await self._db.select_map( - f"SELECT * FROM {self._table_name} WHERE {f"{self._default_filter_condition} AND " if self._default_filter_condition is not None else ''} {self.__primary_key} = {f"'{id}'" if isinstance(id, str) else id}" + result = await self._prepare_query( + filters=[{self.__primary_key: id}], sorts=[{self.__primary_key: "asc"}] ) - if not result or len(result) == 0: - return None - - return self.to_object(result[0]) + return self.to_object(result[0]) if result else None async def get_by( self, @@ -262,23 +259,10 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): take: int = None, skip: int = None, ) -> list[T_DBM]: - """ - Get all objects by the given filters - :param AttributeFilter filters: - :param AttributeSorts sorts: - :param int skip: - :param int take: - :return: List of objects - :rtype: list[T_DBM] - :raises ValueError: When no result is found - """ - result = await self._db.select_map( - await self._build_conditional_query(filters, sorts, take, skip) - ) + result = await self._prepare_query(filters, sorts, take, skip) if not result or len(result) == 0: raise ValueError("No result found") - - return [self.to_object(x) for x in result] + return [self.to_object(x) for x in result] if result else [] async def get_single_by( self, @@ -287,23 +271,12 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): take: int = None, skip: int = None, ) -> T_DBM: - """ - Get a single object by the given filters - :param AttributeFilter filters: - :param AttributeSorts sorts: - :param int skip: - :param int take: - :return: Single object - :rtype: T_DBM - :raises ValueError: When no result is found - :raises ValueError: When more than one result is found - """ - result = await self.get_by(filters, sorts, take, skip) + result = await self._prepare_query(filters, sorts, take, skip) if not result: raise ValueError("No result found") if len(result) > 1: raise ValueError("More than one result found") - return result[0] + return self.to_object(result[0]) async def find_by( self, @@ -311,23 +284,9 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): sorts: AttributeSorts = None, take: int = None, skip: int = None, - ) -> list[Optional[T_DBM]]: - """ - Find all objects by the given filters - :param AttributeFilter filters: - :param AttributeSorts sorts: - :param int skip: - :param int take: - :return: List of objects - :rtype: list[Optional[T_DBM]] - """ - result = await self._db.select_map( - await self._build_conditional_query(filters, sorts, take, skip) - ) - if not result or len(result) == 0: - return [] - - return [self.to_object(x) for x in result] + ) -> list[T_DBM]: + result = await self._prepare_query(filters, sorts, take, skip) + return [self.to_object(x) for x in result] if result else [] async def find_single_by( self, @@ -336,22 +295,10 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): take: int = None, skip: int = None, ) -> Optional[T_DBM]: - """ - Find a single object by the given filters - :param AttributeFilter filters: - :param AttributeSorts sorts: - :param int skip: - :param int take: - :return: Single object - :rtype: Optional[T_DBM] - :raises ValueError: When more than one result is found - """ - result = await self.find_by(filters, sorts, take, skip) - if not result or len(result) == 0: - return None + result = await self._prepare_query(filters, sorts, take, skip) if len(result) > 1: raise ValueError("More than one result found") - return result[0] + return self.to_object(result[0]) if result else None async def touch(self, obj: T_DBM): """ @@ -520,15 +467,289 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): await self._db.execute(query) - def _get_primary_key_value_sql(self, obj: T_DBM) -> str: - value = getattr(obj, self.__primary_key) - if isinstance(value, str): - return f"'{value}'" + async def _prepare_query( + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, + for_count=False, + ) -> list[dict]: + """ + Prepares and executes a query using the SQLBuilder with the given parameters. + :param filters: Conditions to filter the query. + :param sorts: Sorting attributes and directions. + :param take: Limit the number of results. + :param skip: Offset the results. + :return: Query result as a list of dictionaries. + """ + external_table_deps = [] + builder = SQLSelectBuilder(self._table_name, self.__primary_key) - return value + for temp in self._external_fields: + builder.with_temp_table(self._external_fields[temp]) + + if for_count: + builder.with_attribute("COUNT(*)", ignore_table_name=True) + else: + builder.with_attribute("*") + + for attr in self.__foreign_tables: + table, join_condition = self.__foreign_tables[attr] + builder.with_left_join(table, join_condition) + + if filters: + await self._build_conditions(builder, filters, external_table_deps) + + if sorts: + self._build_sorts(builder, sorts, external_table_deps) + + if take: + builder.with_limit(take) + + if skip: + builder.with_offset(skip) + + for external_table in external_table_deps: + builder.use_temp_table(external_table) + + query = await builder.build() + return await self._db.select_map(query) + + async def _build_conditions( + self, + builder: SQLSelectBuilder, + filters: AttributeFilters, + external_table_deps: list[str], + ): + """ + Builds SQL conditions from GraphQL-like filters and adds them to the SQLBuilder. + :param builder: The SQLBuilder instance to add conditions to. + :param filters: GraphQL-like filter structure. + :param external_table_deps: List to store external table dependencies. + """ + if not isinstance(filters, list): + filters = [filters] + + for filter_group in filters: + sql_conditions = self._graphql_to_sql_conditions( + filter_group, external_table_deps + ) + for attr, operator, value in sql_conditions: + if attr in self.__foreign_table_keys: + attr = self.__foreign_table_keys[attr] + + external_table = self._get_external_field_key(attr) + if external_table: + external_table_deps.append(external_table) + + if operator == "fuzzy": + builder.with_levenshtein_condition(attr) + elif operator in [ + "IS NULL", + "IS NOT NULL", + ]: # operator without value + builder.with_condition(attr, operator) + else: + builder.with_value_condition( + attr, operator, self._get_value_sql(value) + ) + + def _graphql_to_sql_conditions( + self, graphql_structure: dict, external_table_deps: list[str] + ) -> list[tuple[str, str, Any]]: + """ + Converts a GraphQL-like structure to SQL conditions. + :param graphql_structure: The GraphQL-like filter structure. + :param external_table_deps: List to track external table dependencies. + :return: A list of tuples (attribute, operator, value). + """ + + operators = { + "equal": "=", + "notEqual": "!=", + "greater": ">", + "greaterOrEqual": ">=", + "less": "<", + "lessOrEqual": "<=", + "isNull": "IS NULL", + "isNotNull": "IS NOT NULL", + "contains": "LIKE", # Special handling in _graphql_to_sql_conditions + "notContains": "NOT LIKE", # Special handling in _graphql_to_sql_conditions + "startsWith": "LIKE", # Special handling in _graphql_to_sql_conditions + "endsWith": "LIKE", # Special handling in _graphql_to_sql_conditions + "in": "IN", + "notIn": "NOT IN", + } + conditions = [] + + def parse_node(node, parent_key=None): + if not isinstance(node, dict): + return + + if isinstance(node, list): + conditions.append((parent_key, "IN", node)) + return + + for key, value in node.items(): + if isinstance(key, property): + key = key.fget.__name__ + + external_fields_table_name_by_parent = self._get_external_field_key( + parent_key + ) + external_fields_table_name = self._get_external_field_key(key) + external_field = ( + external_fields_table_name + if external_fields_table_name_by_parent is None + else external_fields_table_name_by_parent + ) + + if key == "fuzzy": + self._handle_fuzzy_filter_conditions( + conditions, external_table_deps, value + ) + + elif external_field is not None: + external_table_deps.append(external_field) + parse_node(value, f"{external_field}.{key}") + elif parent_key in self.__foreign_table_keys: + parse_node({key: value}, self.__foreign_table_keys[parent_key]) + elif key in operators: + operator = operators[key] + if key == "contains" or key == "notContains": + value = f"%{value}%" + elif key == "in" or key == "notIn": + value = value + elif key == "startsWith": + value = f"{value}%" + elif key == "endsWith": + value = f"%{value}" + elif (key == "equal" or key == "notEqual") and value is None: + operator = operators["isNull"] + + conditions.append((parent_key, operator, value)) + + elif isinstance(value, dict): + if key in self.__foreign_table_keys: + parse_node(value, key) + elif key in self.__db_names: + parse_node(value, self.__db_names[key]) + else: + parse_node(value, key) + elif value is None: + conditions.append((self.__db_names[key], "IS NULL", value)) + else: + conditions.append((self.__db_names[key], "=", value)) + + parse_node(graphql_structure) + return conditions + + def _handle_fuzzy_filter_conditions( + self, conditions, external_field_table_deps, sub_values + ): + # Extract fuzzy filter parameters + fuzzy_fields = get_value(sub_values, "fields", list[str]) + fuzzy_term = get_value(sub_values, "term", str) + fuzzy_threshold = get_value(sub_values, "threshold", int, 5) + + if not fuzzy_fields or not fuzzy_term: + raise ValueError("Fuzzy filter must include 'fields' and 'term'.") + + fuzzy_fields_db_names = [] + + # Map fields to their database names + for fuzzy_field in fuzzy_fields: + external_fields_table_name = self._get_external_field_key(fuzzy_field) + if external_fields_table_name is not None: + external_fields_table = self._external_fields[ + external_fields_table_name + ] + fuzzy_fields_db_names.append( + f"{external_fields_table.table_name}.{fuzzy_field}" + ) + external_field_table_deps.append(external_fields_table.table_name) + elif fuzzy_field in self.__db_names: + fuzzy_fields_db_names.append( + f"{self._table_name}.{self.__db_names[fuzzy_field]}" + ) + elif fuzzy_field in self.__foreign_tables: + fuzzy_fields_db_names.append( + f"{self._table_name}.{self.__foreign_table_keys[fuzzy_field]}" + ) + else: + fuzzy_fields_db_names.append( + self.__db_names[camel_to_snake(fuzzy_field)][0] + ) + + # Build fuzzy conditions for each field + fuzzy_conditions = self._build_fuzzy_conditions( + fuzzy_fields_db_names, fuzzy_term, fuzzy_threshold + ) + + # Combine conditions with OR and append to the main conditions + conditions.append((f"({' OR '.join(fuzzy_conditions)})", "fuzzy", None)) @staticmethod - def _get_value_sql(value: Any) -> str: + def _build_fuzzy_conditions( + fields: list[str], term: str, threshold: int = 10 + ) -> list[str]: + conditions = [] + for field in fields: + conditions.append( + f"levenshtein({field}::TEXT, '{term}') <= {threshold}" + ) # Adjust the threshold as needed + + return conditions + + def _get_external_field_key(self, field_name: str) -> Optional[str]: + """ + Returns the key to get the external field if found, otherwise None. + :param str field_name: The name of the field to search for. + :return: The key if found, otherwise None. + :rtype: Optional[str] + """ + for key, builder in self._external_fields.items(): + if field_name in builder.fields and field_name not in self.__db_names: + return key + return None + + def _build_sorts( + self, + builder: SQLSelectBuilder, + sorts: AttributeSorts, + external_table_deps: list[str], + ): + """ + Resolves complex sorting structures into SQL-compatible sorting conditions. + Tracks external table dependencies. + :param builder: The SQLBuilder instance to add sorting to. + :param sorts: Sorting attributes and directions in a complex structure. + :param external_table_deps: List to track external table dependencies. + """ + + def parse_sort_node(node): + if isinstance(node, dict): + for key, value in node.items(): + if isinstance(value, dict): + # Recursively parse nested structures + parse_sort_node(value) + elif isinstance(value, str) and value.lower() in ["asc", "desc"]: + external_table = self._get_external_field_key(key) + if external_table: + external_table_deps.append(external_table) + builder.with_order_by(key, value.upper()) + else: + raise ValueError(f"Invalid sort direction: {value}") + elif isinstance(node, list): + for item in node: + parse_sort_node(item) + else: + raise ValueError(f"Invalid sort structure: {node}") + + parse_sort_node(sorts) + + def _get_value_sql(self, value: Any) -> str: if isinstance(value, str): if value.lower() == "null": return "NULL" @@ -548,8 +769,8 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): if isinstance(value, list): if len(value) == 0: - return "ARRAY[]::text[]" - return f"ARRAY[{", ".join([DataAccessObjectABC._get_value_sql(x) for x in value])}]" + return "()" + return f"({', '.join([self._get_value_sql(x) for x in value])})" if isinstance(value, datetime.datetime): if value.tzinfo is None: @@ -578,425 +799,17 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return cast_type(value) - async def _handle_query_external_temp_tables( - self, query: str, external_table_deps: list[str], ignore_fields=False - ) -> str: - for dep in external_table_deps: - temp_table = self._external_fields[dep] - temp_table_sql = await temp_table.build() + def _get_primary_key_value_sql(self, obj: T_DBM) -> str: + value = getattr(obj, self.__primary_key) + if isinstance(value, str): + return f"'{value}'" - if not ignore_fields: - query = query.replace( - " FROM", - f", {','.join([f'{temp_table.table_name}.{x}' for x in temp_table.fields.keys() if x not in self.__db_names])} FROM", - ) - - query = f"{temp_table_sql}\n{query}" - query += f" LEFT JOIN {temp_table.table_name} ON {temp_table.join_ref_table}.{self.__primary_key} = {temp_table.table_name}.{temp_table.primary_key}" - - return query - - async def _build_conditional_query( - self, - filters: AttributeFilters = None, - sorts: AttributeSorts = None, - take: int = None, - skip: int = None, - ) -> str: - filter_conditions = [] - sort_conditions = [] - - external_table_deps = [] - query = f"SELECT {self._table_name}.* FROM {self._table_name}" - join_str = f" ".join(self.__joins) - query += f" {join_str}" if len(join_str) > 0 else "" - - # Collect dependencies from filters - if filters is not None and (not isinstance(filters, list) or len(filters) > 0): - filter_conditions, filter_deps = await self._build_conditions(filters) - external_table_deps.extend(filter_deps) - - # Collect dependencies from sorts - if sorts is not None and (not isinstance(sorts, list) or len(sorts) > 0): - sort_conditions, sort_deps = self._build_order_by(sorts) - external_table_deps.extend(sort_deps) - - # Handle external table dependencies before WHERE and ORDER BY - if external_table_deps: - query = await self._handle_query_external_temp_tables( - query, external_table_deps - ) - - # Add WHERE clause - if filters is not None and (not isinstance(filters, list) or len(filters) > 0): - query += f" WHERE {filter_conditions}" - - # Add ORDER BY clause - if sorts is not None and (not isinstance(sorts, list) or len(sorts) > 0): - query += f" ORDER BY {sort_conditions}" - - if take is not None: - query += f" LIMIT {take}" - - if skip is not None: - query += f" OFFSET {skip}" - - if not query.endswith(";"): - query += ";" - return query - - def _get_external_field_key(self, field_name: str) -> Optional[str]: - """ - Returns the key to get the external field if found, otherwise None. - :param str field_name: The name of the field to search for. - :return: The key if found, otherwise None. - :rtype: Optional[str] - """ - for key, builder in self._external_fields.items(): - if field_name in builder.fields and field_name not in self.__db_names: - return key - return None - - async def _build_conditions(self, filters: AttributeFilters) -> (str, list[str]): - """ - Build SQL conditions from the given filters - :param filters: - :return: SQL conditions & External field table dependencies - """ - external_field_table_deps = [] - if not isinstance(filters, list): - filters = [filters] - - conditions = [] - for f in filters: - f_conditions = [] - - for attr, values in f.items(): - if isinstance(attr, property): - attr = attr.fget.__name__ - - if attr in self.__foreign_tables: - foreign_table = self.__foreign_tables[attr] - cons, eftd = self._build_foreign_conditions( - attr, foreign_table, values - ) - if eftd: - external_field_table_deps.extend(eftd) - - f_conditions.extend(cons) - continue - - if attr == "fuzzy": - self._handle_fuzzy_filter_conditions( - f_conditions, external_field_table_deps, values - ) - continue - - external_fields_table_name = self._get_external_field_key(attr) - if external_fields_table_name is not None: - external_fields_table = self._external_fields[ - external_fields_table_name - ] - db_name = f"{external_fields_table.table_name}.{attr}" - external_field_table_deps.append(external_fields_table.table_name) - elif ( - isinstance(values, dict) or isinstance(values, list) - ) and not attr in self.__foreign_tables: - db_name = f"{self._table_name}.{self.__db_names[attr]}" - elif attr in self._operators: - db_name = f"{self._table_name}.{self.__db_names[attr]}" - else: - db_name = self.__db_names[attr] - - if isinstance(values, dict): - for operator, value in values.items(): - f_conditions.append( - self._build_condition(f"{db_name}", operator, value) - ) - elif isinstance(values, list): - sub_conditions = [] - for value in values: - if isinstance(value, dict): - for operator, val in value.items(): - sub_conditions.append( - self._build_condition(f"{db_name}", operator, val) - ) - else: - sub_conditions.append( - self._get_value_validation_sql(db_name, value) - ) - f_conditions.append(f"({' OR '.join(sub_conditions)})") - elif attr in self._operators: - conditions.append(f"{self._build_condition(db_name, attr, values)}") - else: - f_conditions.append(self._get_value_validation_sql(db_name, values)) - - conditions.append(f"({' OR '.join(f_conditions)})") - - return " AND ".join(conditions), external_field_table_deps - - @staticmethod - def _build_fuzzy_conditions( - fields: list[str], term: str, threshold: int = 10 - ) -> list[str]: - conditions = [] - for field in fields: - conditions.append( - f"levenshtein({field}::TEXT, '{term}') <= {threshold}" - ) # Adjust the threshold as needed - - return conditions - - def _build_foreign_conditions( - self, base_attr: str, table: str, values: dict - ) -> (list[str], list[str]): - """ - Build SQL conditions for foreign key references - :param base_attr: Base attribute name - :param table: Foreign table name - :param values: Filter values - :return: List of conditions, List of external field tables - """ - external_field_table_deps = [] - conditions = [] - for attr, sub_values in values.items(): - if isinstance(attr, property): - attr = attr.fget.__name__ - - if attr in self.__foreign_tables: - foreign_table = self.__foreign_tables[attr] - sub_conditions, eftd = self._build_foreign_conditions( - attr, foreign_table, sub_values - ) - if len(eftd) > 0: - external_field_table_deps.extend(eftd) - - conditions.extend(sub_conditions) - continue - - if attr == "fuzzy": - self._handle_fuzzy_filter_conditions( - conditions, external_field_table_deps, sub_values - ) - continue - - external_fields_table_name = self._get_external_field_key(attr) - if external_fields_table_name is not None: - external_fields_table = self._external_fields[ - external_fields_table_name - ] - db_name = f"{external_fields_table.table_name}.{attr}" - external_field_table_deps.append(external_fields_table.table_name) - elif attr in self._operators: - db_name = f"{self._table_name}.{self.__foreign_table_keys[base_attr]}" - else: - db_name = f"{table}.{attr.lower().replace('_', '')}" - - if isinstance(sub_values, dict): - for operator, value in sub_values.items(): - conditions.append( - f"{self._build_condition(db_name, operator, value)}" - ) - elif isinstance(sub_values, list): - sub_conditions = [] - for value in sub_values: - if isinstance(value, dict): - for operator, val in value.items(): - sub_conditions.append( - f"{self._build_condition(db_name, operator, val)}" - ) - else: - sub_conditions.append( - self._get_value_validation_sql(db_name, value) - ) - conditions.append(f"({' OR '.join(sub_conditions)})") - elif attr in self._operators: - conditions.append(f"{self._build_condition(db_name, attr, sub_values)}") - else: - conditions.append(self._get_value_validation_sql(db_name, sub_values)) - - return conditions, external_field_table_deps - - def _handle_fuzzy_filter_conditions( - self, conditions, external_field_table_deps, sub_values - ): - fuzzy_fields = get_value(sub_values, "fields", list[str]) - fuzzy_fields_db_names = [] - for fuzzy_field in fuzzy_fields: - external_fields_table_name = self._get_external_field_key(fuzzy_field) - if external_fields_table_name is not None: - external_fields_table = self._external_fields[ - external_fields_table_name - ] - fuzzy_fields_db_names.append( - f"{external_fields_table.table_name}.{fuzzy_field}" - ) - external_field_table_deps.append(external_fields_table.table_name) - elif fuzzy_field in self.__db_names: - fuzzy_fields_db_names.append( - f"{self._table_name}.{self.__db_names[fuzzy_field]}" - ) - elif fuzzy_field in self.__foreign_tables: - fuzzy_fields_db_names.append( - f"{self._table_name}.{self.__foreign_table_keys[fuzzy_field]}" - ) - else: - fuzzy_fields_db_names.append( - self.__db_names[camel_to_snake(fuzzy_field)] - ) - conditions.append( - f"({' OR '.join( - self._build_fuzzy_conditions( - [x for x in fuzzy_fields_db_names], - get_value(sub_values, "term", str), - get_value(sub_values, "threshold", int, 5), - ) - ) - })" - ) - - def _get_value_validation_sql(self, field: str, value: Any): - value = self._get_value_sql(value) - field_selector = ( - f"{self._table_name}.{field}" - if not field.startswith(self._table_name) - else field - ) - if field in self.__foreign_tables: - field_selector = self.__db_names[field] - - if value == "NULL": - return f"{field_selector} IS NULL" - return f"{field_selector} = {value}" - - def _build_condition(self, db_name: str, operator: str, value: Any) -> str: - """ - Build individual SQL condition based on the operator - :param db_name: - :param operator: - :param value: - :return: - """ - attr = db_name.split(".")[-1] - - if attr in self.__date_attributes: - db_name = self._attr_from_date_to_char(db_name) - - sql_value = self._get_value_sql(value) - if operator == "equal": - return f"{db_name} = {sql_value}" - elif operator == "notEqual": - return f"{db_name} != {sql_value}" - elif operator == "greater": - return f"{db_name} > {sql_value}" - elif operator == "greaterOrEqual": - return f"{db_name} >= {sql_value}" - elif operator == "less": - return f"{db_name} < {sql_value}" - elif operator == "lessOrEqual": - return f"{db_name} <= {sql_value}" - elif operator == "isNull": - return f"{db_name} IS NULL" if sql_value else f"{db_name} IS NOT NULL" - elif operator == "isNotNull": - return f"{db_name} IS NOT NULL" if sql_value else f"{db_name} IS NULL" - elif operator == "contains": - return f"{db_name} LIKE '%{value}%'" - elif operator == "notContains": - return f"{db_name} NOT LIKE '%{value}%'" - elif operator == "startsWith": - return f"{db_name} LIKE '{value}%'" - elif operator == "endsWith": - return f"{db_name} LIKE '%{value}'" - elif operator == "in": - return ( - f"{db_name} IN ({', '.join([self._get_value_sql(x) for x in value])})" - ) - elif operator == "notIn": - return f"{db_name} NOT IN ({', '.join([self._get_value_sql(x) for x in value])})" - else: - raise ValueError(f"Unsupported operator: {operator}") + return value @staticmethod def _attr_from_date_to_char(attr: str) -> str: return f"TO_CHAR({attr}, 'YYYY-MM-DD HH24:MI:SS.US TZ')" - def _build_order_by(self, sorts: AttributeSorts) -> (str, list[str]): - """ - Build SQL order by clause from the given sorts - :param sorts: - :return: - """ - external_field_table_deps = [] - if not isinstance(sorts, list): - sorts = [sorts] - - sort_clauses = [] - for sort in sorts: - for attr, direction in sort.items(): - if isinstance(attr, property): - attr = attr.fget.__name__ - - if attr in self.__foreign_tables: - foreign_table = self.__foreign_tables[attr] - f_sorts, eftd = self._build_foreign_order_by( - foreign_table, direction - ) - if eftd: - external_field_table_deps.extend(eftd) - - sort_clauses.extend(f_sorts) - continue - - external_fields_table_name = self._get_external_field_key(attr) - if external_fields_table_name is not None: - external_fields_table = self._external_fields[ - external_fields_table_name - ] - db_name = f"{external_fields_table.table_name}.{attr}" - external_field_table_deps.append(external_fields_table.table_name) - else: - db_name = self.__db_names[attr] - sort_clauses.append(f"{db_name} {direction.upper()}") - - return ", ".join(sort_clauses), external_field_table_deps - - def _build_foreign_order_by( - self, table: str, direction: dict - ) -> (list[str], list[str]): - """ - Build SQL order by clause for foreign key references - :param table: Foreign table name - :param direction: Sort direction - :return: List of order by clauses - """ - external_field_table_deps = [] - sort_clauses = [] - for attr, sub_direction in direction.items(): - if isinstance(attr, property): - attr = attr.fget.__name__ - - if attr in self.__foreign_tables: - foreign_table = self.__foreign_tables[attr] - f_sorts, eftd = self._build_foreign_order_by(foreign_table, direction) - if eftd: - external_field_table_deps.extend(eftd) - - sort_clauses.extend(f_sorts) - continue - - external_fields_table_name = self._get_external_field_key(attr) - if external_fields_table_name is not None: - external_fields_table = self._external_fields[ - external_fields_table_name - ] - db_name = f"{external_fields_table.table_name}.{attr}" - external_field_table_deps.append(external_fields_table.table_name) - else: - db_name = f"{table}.{attr.lower().replace('_', '')}" - sort_clauses.append(f"{db_name} {sub_direction.upper()}") - - return sort_clauses, external_field_table_deps - @staticmethod async def _get_editor_id(obj: T_DBM): editor_id = obj.editor_id diff --git a/api/src/core/database/db_context.py b/api/src/core/database/db_context.py index 7e363c9..3c169c9 100644 --- a/api/src/core/database/db_context.py +++ b/api/src/core/database/db_context.py @@ -39,13 +39,13 @@ class DBContext: return await self._pool.select_map(statement, args) except (OperationalError, PoolTimeout) as e: if self._fails >= 3: - logger.error(f"Database error caused by {statement}", e) + logger.error(f"Database error caused by `{statement}`", e) uid = uuid.uuid4() raise Exception( f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" ) - logger.error(f"Database error caused by {statement}", e) + logger.error(f"Database error caused by `{statement}`", e) self._fails += 1 try: logger.debug("Retry select") @@ -54,7 +54,7 @@ class DBContext: pass return [] except Exception as e: - logger.error(f"Database error caused by {statement}", e) + logger.error(f"Database error caused by `{statement}`", e) raise e async def select( @@ -65,13 +65,13 @@ class DBContext: return await self._pool.select(statement, args) except (OperationalError, PoolTimeout) as e: if self._fails >= 3: - logger.error(f"Database error caused by {statement}", e) + logger.error(f"Database error caused by `{statement}`", e) uid = uuid.uuid4() raise Exception( f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" ) - logger.error(f"Database error caused by {statement}", e) + logger.error(f"Database error caused by `{statement}`", e) self._fails += 1 try: logger.debug("Retry select") @@ -80,5 +80,5 @@ class DBContext: pass return [] except Exception as e: - logger.error(f"Database error caused by {statement}", e) + logger.error(f"Database error caused by `{statement}`", e) raise e diff --git a/api/src/core/database/sql_select_builder.py b/api/src/core/database/sql_select_builder.py new file mode 100644 index 0000000..f2644ed --- /dev/null +++ b/api/src/core/database/sql_select_builder.py @@ -0,0 +1,150 @@ +from typing import Optional + +from core.database.external_data_temp_table_builder import ExternalDataTempTableBuilder + + +class SQLSelectBuilder: + + def __init__(self, table_name: str, primary_key: str): + self._table_name = table_name + self._primary_key = primary_key + + self._temp_tables: dict[str, ExternalDataTempTableBuilder] = {} + self._to_use_temp_tables: list[str] = [] + self._attributes: list[str] = [] + self._tables: list[str] = [table_name] + self._joins: dict[str, (str, str)] = {} + self._conditions: list[str] = [] + self._order_by: str = "" + self._limit: Optional[int] = None + self._offset: Optional[int] = None + + def with_temp_table( + self, temp_table: ExternalDataTempTableBuilder + ) -> "SQLSelectBuilder": + self._temp_tables[temp_table.table_name] = temp_table + return self + + def use_temp_table(self, temp_table_name: str): + if temp_table_name not in self._temp_tables: + raise ValueError(f"Temp table {temp_table_name} not found.") + + self._to_use_temp_tables.append(temp_table_name) + + def with_attribute(self, attr: str, ignore_table_name=False) -> "SQLSelectBuilder": + if not ignore_table_name and not attr.startswith(self._table_name): + attr = f"{self._table_name}.{attr}" + + self._attributes.append(attr) + return self + + def with_foreign_attribute(self, attr: str) -> "SQLSelectBuilder": + self._attributes.append(attr) + return self + + def with_table(self, table_name: str) -> "SQLSelectBuilder": + self._tables.append(table_name) + return self + + def _check_prefix(self, attr: str) -> str: + assert attr is not None + + valid_prefixes = [ + "levenshtein", + self._table_name, + *self._joins.keys(), + *self._temp_tables.keys(), + ] + if not any(attr.startswith(f"{prefix}.") for prefix in valid_prefixes): + attr = f"{self._table_name}.{attr}" + + return attr + + def with_value_condition( + self, attr: str, operator: str, value: str + ) -> "SQLSelectBuilder": + attr = self._check_prefix(attr) + self._conditions.append(f"{attr} {operator} {value}") + return self + + def with_levenshtein_condition(self, condition: str) -> "SQLSelectBuilder": + self._conditions.append(condition) + return self + + def with_condition(self, attr: str, operator: str) -> "SQLSelectBuilder": + attr = self._check_prefix(attr) + self._conditions.append(f"{attr} {operator}") + return self + + def with_grouped_conditions(self, conditions: list[str]) -> "SQLSelectBuilder": + self._conditions.append(f"({' AND '.join(conditions)})") + return self + + def with_left_join(self, table: str, on: str) -> "SQLSelectBuilder": + if table in self._joins: + self._joins[table] = (f"{self._joins[table][0]} AND {on}", "LEFT") + + self._joins[table] = (on, "LEFT") + return self + + def with_inner_join(self, table: str, on: str) -> "SQLSelectBuilder": + if table in self._joins: + self._joins[table] = (f"{self._joins[table][0]} AND {on}", "INNER") + + self._joins[table] = (on, "INNER") + return self + + def with_right_join(self, table: str, on: str) -> "SQLSelectBuilder": + if table in self._joins: + self._joins[table] = (f"{self._joins[table][0]} AND {on}", "RIGHT") + + self._joins[table] = (on, "RIGHT") + return self + + def with_limit(self, limit: int) -> "SQLSelectBuilder": + self._limit = limit + return self + + def with_offset(self, offset: int) -> "SQLSelectBuilder": + self._offset = offset + return self + + def with_order_by(self, column: str, direction: str = "ASC") -> "SQLSelectBuilder": + self._order_by = f"{column} {direction}" + return self + + async def _handle_temp_table_use(self, query) -> str: + new_query = "" + + for temp_table_name in self._to_use_temp_tables: + temp_table = self._temp_tables[temp_table_name] + new_query += await self._temp_tables[temp_table_name].build() + self.with_left_join( + temp_table.table_name, + f"{temp_table.join_ref_table}.{self._primary_key} = {temp_table.table_name}.{temp_table.primary_key}", + ) + + return f"{new_query} {query}" if new_query != "" else query + + async def build(self) -> str: + query = await self._handle_temp_table_use("") + + attributes = ", ".join(self._attributes) if self._attributes else "*" + query += f"SELECT {attributes} FROM {", ".join(self._tables)}" + + for join in self._joins: + query += f" {self._joins[join][1]} JOIN {join} ON {self._joins[join][0]}" + + if self._conditions: + query += " WHERE " + " AND ".join(self._conditions) + + if self._order_by: + query += f" ORDER BY {self._order_by}" + + if self._limit is not None: + query += f" LIMIT {self._limit}" + + if self._offset is not None: + query += f" OFFSET {self._offset}" + + return query diff --git a/api/src/data/schemas/public/user_setting_dao.py b/api/src/data/schemas/public/user_setting_dao.py index 78f3411..c6d8af0 100644 --- a/api/src/data/schemas/public/user_setting_dao.py +++ b/api/src/data/schemas/public/user_setting_dao.py @@ -21,4 +21,4 @@ class UserSettingDao(DbModelDaoABC[UserSetting]): ) -userSettingsDao = UserSettingDao() +userSettingDao = UserSettingDao() diff --git a/api/src/data/schemas/system/setting_dao.py b/api/src/data/schemas/system/setting_dao.py index d7262eb..c2c836a 100644 --- a/api/src/data/schemas/system/setting_dao.py +++ b/api/src/data/schemas/system/setting_dao.py @@ -17,4 +17,4 @@ class SettingDao(DbModelDaoABC[Setting]): return await self.find_single_by({Setting.key: key}) -settingsDao = SettingDao() +settingDao = SettingDao() diff --git a/api/src/data/seeder/settings_seeder.py b/api/src/data/seeder/settings_seeder.py index dcdff29..989d637 100644 --- a/api/src/data/seeder/settings_seeder.py +++ b/api/src/data/seeder/settings_seeder.py @@ -3,7 +3,7 @@ from typing import Any from core.logger import DBLogger from data.abc.data_seeder_abc import DataSeederABC from data.schemas.system.setting import Setting -from data.schemas.system.setting_dao import settingsDao +from data.schemas.system.setting_dao import settingDao logger = DBLogger(__name__) @@ -18,8 +18,8 @@ class SettingsSeeder(DataSeederABC): @staticmethod async def _seed_if_not_exists(key: str, value: Any): - existing = await settingsDao.find_by_key(key) + existing = await settingDao.find_by_key(key) if existing is not None: return - await settingsDao.create(Setting(0, key, str(value))) + await settingDao.create(Setting(0, key, str(value))) diff --git a/api/src/service/data_privacy_service.py b/api/src/service/data_privacy_service.py new file mode 100644 index 0000000..974f0d6 --- /dev/null +++ b/api/src/service/data_privacy_service.py @@ -0,0 +1,119 @@ +import importlib +import json +from typing import Type + +from api.auth.keycloak_client import Keycloak +from core.database.abc.data_access_object_abc import DataAccessObjectABC +from core.database.abc.db_model_dao_abc import DbModelDaoABC +from core.logger import Logger +from core.string import first_to_lower +from data.schemas.administration.user_dao import userDao + +logger = Logger("DataPrivacy") + + +class DataPrivacyService: + + @staticmethod + def _dynamic_import_dao(dao_class: Type[DataAccessObjectABC]): + """ + Dynamically import a DAO class and its instance. + :param dao_class: The DAO class to be imported. + :return: The DAO instance. + """ + module = importlib.import_module(dao_class.__module__) + dao_instance = getattr( + module, first_to_lower(first_to_lower(dao_class.__name__)) + ) + return dao_instance + + @classmethod + async def _collect_user_relevant_dao(cls): + """ + Collect all DAO classes that are relevant for data privacy. + :return: List of relevant DAO classes. + """ + # This method should return a list of DAOs that are relevant for data privacy + # For example, it could return a list of DAOs that contain user data + classes: list[DataAccessObjectABC] = [ + cls._dynamic_import_dao(dao) for dao in DbModelDaoABC.__subclasses__() + ] + return [x for x in classes if x.has_attribute("user_id")] + + @classmethod + async def export_user_data(cls, user_id: int): + """ + Export user data from the database. + :param user_id: ID of the user whose data is to be exported. + :return: User data in a structured format. + """ + # Logic to export user data + user = await userDao.find_by_id(user_id) + if user is None: + raise ValueError("User not found") + + collected_data = [userDao.to_dict(await userDao.find_by_id(user_id))] + + daos = await cls._collect_user_relevant_dao() + for dao in daos: + data = await dao.find_by([{"userid": user_id}]) + collected_data.append([dao.to_dict(x) for x in data]) + + return json.dumps(collected_data, default=str) + + @staticmethod + async def anonymize_user(user_id: int): + """ + Anonymize user data in the database. + :param user_id: ID of the user to be anonymized. + """ + user = await userDao.find_by_id(user_id) + if user is None: + raise ValueError("User not found") + + keycloak_id = user.keycloak_id + + # Anonymize internal data + user.keycloak_id = "ANONYMIZED" + userDao.update(user) + + # Anonymize external data + try: + Keycloak.admin.delete_user(keycloak_id) + except Exception as e: + logger.error(f"Failed to anonymize external data for user {user_id}", e) + raise ValueError("Failed to anonymize external data") from e + + @classmethod + async def delete_user_data(cls, user_id: int): + """ + Delete user data from the database. + :param user_id: ID of the user whose data is to be deleted. + """ + user = await userDao.find_by_id(user_id) + if user is None: + raise ValueError("User not found") + + keycloak_id = user.keycloak_id + + daos = await cls._collect_user_relevant_dao() + for dao in daos: + data = await dao.find_by([{"userid": user_id}]) + try: + await dao.delete_many(data, hard_delete=True) + except Exception as e: + logger.error(f"Failed to delete data for user {user_id}", e) + raise ValueError("Failed to delete data") from e + + try: + await userDao.delete(user) + except Exception as e: + logger.error(f"Failed to delete user {user_id}", e) + raise ValueError("Failed to delete user") from e + + # Delete external data + try: + Keycloak.admin.delete_user(keycloak_id) + except Exception as e: + logger.error(f"Failed to delete external data for user {user_id}", e) + raise ValueError("Failed to delete external data") from e diff --git a/web/src/app/components/header/header.component.ts b/web/src/app/components/header/header.component.ts index d4d1d72..6a0a3b0 100644 --- a/web/src/app/components/header/header.component.ts +++ b/web/src/app/components/header/header.component.ts @@ -1,5 +1,5 @@ import { Component, OnDestroy, OnInit } from '@angular/core'; -import { MenuItem, PrimeNGConfig } from 'primeng/api'; +import { MenuItem, MenuItemCommandEvent, PrimeNGConfig } from 'primeng/api'; import { Subject } from 'rxjs'; import { takeUntil } from 'rxjs/operators'; import { TranslateService } from '@ngx-translate/core'; @@ -11,6 +11,7 @@ import { SidebarService } from 'src/app/service/sidebar.service'; import { ConfigService } from 'src/app/service/config.service'; import { UserSettingsService } from 'src/app/service/user_settings.service'; import { SettingsService } from 'src/app/service/settings.service'; +import { environment } from 'src/environments/environment'; @Component({ selector: 'app-header', @@ -48,11 +49,12 @@ export class HeaderComponent implements OnInit, OnDestroy { }); this.auth.user$.pipe(takeUntil(this.unsubscribe$)).subscribe(async user => { + this.user = user; + await this.initMenuLists(); await this.loadTheme(); await this.loadLang(); - this.user = user; this.guiService.loadedGuiSettings$.next(true); }); @@ -117,14 +119,41 @@ export class HeaderComponent implements OnInit, OnDestroy { visible: !!this.user, }, { - separator: true, + label: this.translateService.instant('header.privacy'), + items: [ + { + label: this.translateService.instant('privacy.export_data'), + command: () => {}, + icon: 'pi pi-download', + }, + { + label: this.translateService.instant('privacy.delete_data'), + command: () => {}, + icon: 'pi pi-trash', + }, + ], }, { - label: this.translateService.instant('header.logout'), - command: async () => { - await this.auth.logout(); - }, - icon: 'pi pi-sign-out', + label: this.translateService.instant('header.profile'), + items: [ + { + label: this.translateService.instant('header.edit_profile'), + command: () => { + window.open( + `${this.config.settings.keycloak.url}/realms/${this.config.settings.keycloak.realm}/account`, + '_blank' + ); + }, + icon: 'pi pi-user-edit', + }, + { + label: this.translateService.instant('header.logout'), + command: async () => { + await this.auth.logout(); + }, + icon: 'pi pi-sign-out', + }, + ], }, ]; }