diff --git a/src/bot_data/abc/auth_user_repository_abc.py b/src/bot_data/abc/auth_user_repository_abc.py new file mode 100644 index 0000000000..161187461e --- /dev/null +++ b/src/bot_data/abc/auth_user_repository_abc.py @@ -0,0 +1,41 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from cpl_query.extension import List + +from bot_api.filter.auth_user_select_criteria import AuthUserSelectCriteria +from bot_data.filtered_result import FilteredResult +from bot_data.model.auth_user import AuthUser + + +class AuthUserRepositoryABC(ABC): + + @abstractmethod + def __init__(self): pass + + @abstractmethod + def get_all_auth_users(self) -> List[AuthUser]: pass + + @abstractmethod + def get_filtered_auth_users_async(self, criteria: AuthUserSelectCriteria) -> FilteredResult: pass + + @abstractmethod + def get_auth_user_by_email_async(self, email: str) -> AuthUser: pass + + @abstractmethod + def find_auth_user_by_email_async(self, email: str) -> Optional[AuthUser]: pass + + @abstractmethod + def find_auth_user_by_confirmation_id_async(self, id: str) -> Optional[AuthUser]: pass + + @abstractmethod + def find_auth_user_by_forgot_password_id_async(self, id: str) -> Optional[AuthUser]: pass + + @abstractmethod + def add_auth_user(self, user: AuthUser): pass + + @abstractmethod + def update_auth_user(self, user: AuthUser): pass + + @abstractmethod + def delete_auth_user(self, user: AuthUser): pass diff --git a/src/bot_data/data_module.py b/src/bot_data/data_module.py index 99fc12ad0d..7537b6f282 100644 --- a/src/bot_data/data_module.py +++ b/src/bot_data/data_module.py @@ -5,6 +5,7 @@ from cpl_discord.service.discord_collection_abc import DiscordCollectionABC from bot_core.abc.module_abc import ModuleABC from bot_core.configuration.feature_flags_enum import FeatureFlagsEnum +from bot_data.abc.auth_user_repository_abc import AuthUserRepositoryABC from bot_data.abc.auto_role_repository_abc import AutoRoleRepositoryABC from bot_data.abc.client_repository_abc import ClientRepositoryABC from bot_data.abc.known_user_repository_abc import KnownUserRepositoryABC @@ -12,6 +13,7 @@ from bot_data.abc.server_repository_abc import ServerRepositoryABC from bot_data.abc.user_joined_server_repository_abc import UserJoinedServerRepositoryABC from bot_data.abc.user_joined_voice_channel_abc import UserJoinedVoiceChannelRepositoryABC from bot_data.abc.user_repository_abc import UserRepositoryABC +from bot_data.service.auth_user_repository_service import AuthUserRepositoryService from bot_data.service.auto_role_repository_service import AutoRoleRepositoryService from bot_data.service.client_repository_service import ClientRepositoryService from bot_data.service.known_user_repository_service import KnownUserRepositoryService @@ -30,6 +32,7 @@ class DataModule(ModuleABC): pass def configure_services(self, services: ServiceCollectionABC, env: ApplicationEnvironmentABC): + services.add_transient(AuthUserRepositoryABC, AuthUserRepositoryService) services.add_transient(ServerRepositoryABC, ServerRepositoryService) services.add_transient(UserRepositoryABC, UserRepositoryService) services.add_transient(ClientRepositoryABC, ClientRepositoryService) diff --git a/src/bot_data/filtered_result.py b/src/bot_data/filtered_result.py new file mode 100644 index 0000000000..8bdf90a314 --- /dev/null +++ b/src/bot_data/filtered_result.py @@ -0,0 +1,24 @@ +from cpl_query.extension import List + + +class FilteredResult: + + def __init__(self, result: List = None, total_count: int = 0): + self._result = [] if result is None else result + self._total_count = total_count + + @property + def result(self) -> List: + return self._result + + @result.setter + def result(self, value: List): + self._result = value + + @property + def total_count(self) -> int: + return self._total_count + + @total_count.setter + def total_count(self, value: int): + self._total_count = value diff --git a/src/bot_data/model/auth_user.py b/src/bot_data/model/auth_user.py index 37b50d99c9..6078f2099b 100644 --- a/src/bot_data/model/auth_user.py +++ b/src/bot_data/model/auth_user.py @@ -33,12 +33,44 @@ class AuthUser(TableABC): self._forgot_password_id = forgot_password_id self._refresh_token_expire_time = refresh_token_expire_time - self._auth_role_id = auth_role.value + self._auth_role_id = auth_role TableABC.__init__(self) self._created_at = created_at if created_at is not None else self._created_at self._modified_at = modified_at if modified_at is not None else self._modified_at + @property + def id(self) -> int: + return self._auth_user_id + + @property + def first_name(self) -> str: + return self._first_name + + @property + def last_name(self) -> str: + return self._last_name + + @property + def email(self) -> str: + return self._email + + @property + def password(self) -> str: + return self._password + + @property + def refresh_token(self) -> str: + return self._refresh_token + + @property + def refresh_token_expire_time(self) -> datetime: + return self._refresh_token_expire_time + + @property + def auth_role(self) -> AuthRoleEnum: + return self._auth_role_id + @staticmethod def get_select_all_string() -> str: return str(f""" @@ -52,6 +84,27 @@ class AuthUser(TableABC): WHERE `Id` = {id}; """) + @staticmethod + def get_select_by_email_string(email: str) -> str: + return str(f""" + SELECT * FROM `AuthUsers` + WHERE `EMail` = {email}; + """) + + @staticmethod + def get_select_by_confirmation_id_string(id: str) -> str: + return str(f""" + SELECT * FROM `AuthUsers` + WHERE `ConfirmationId` = {id}; + """) + + @staticmethod + def get_select_by_forgot_password_i_string(id: str) -> str: + return str(f""" + SELECT * FROM `AuthUsers` + WHERE `ForgotPasswordId` = {id}; + """) + @property def insert_string(self) -> str: return str(f""" @@ -78,7 +131,7 @@ class AuthUser(TableABC): {self._confirmation_id}, {self._forgot_password_id}, {self._refresh_token_expire_time}, - {self._auth_role_id} + {self._auth_role_id.value} {self._created_at}, {self._modified_at} ) @@ -96,7 +149,7 @@ class AuthUser(TableABC): `ConfirmationId` = '{self._confirmation_id}', `ForgotPasswordId` = '{self._forgot_password_id}', `RefreshTokenExpiryTime` = '{self._refresh_token_expire_time}', - `AutoRole` = {self._auth_role_id}, + `AutoRole` = {self._auth_role_id.value}, `LastModifiedAt` = '{self._modified_at}' WHERE `AuthUsers`.`Id` = {self._auth_user_id}; """) diff --git a/src/bot_data/service/auth_user_repository_service.py b/src/bot_data/service/auth_user_repository_service.py new file mode 100644 index 0000000000..aca3e94af2 --- /dev/null +++ b/src/bot_data/service/auth_user_repository_service.py @@ -0,0 +1,125 @@ +from typing import Optional + +from cpl_core.database.context import DatabaseContextABC +from cpl_query.extension import List + +from bot_api.filter.auth_user_select_criteria import AuthUserSelectCriteria +from bot_core.logging.database_logger import DatabaseLogger +from bot_data.abc.auth_user_repository_abc import AuthUserRepositoryABC +from bot_data.filtered_result import FilteredResult +from bot_data.model.auth_role_enum import AuthRoleEnum +from bot_data.model.auth_user import AuthUser + + +class AuthUserRepositoryService(AuthUserRepositoryABC): + + def __init__(self, logger: DatabaseLogger, db_context: DatabaseContextABC): + self._logger = logger + self._context = db_context + + AuthUserRepositoryABC.__init__(self) + + @staticmethod + def _user_from_result(result: tuple) -> AuthUser: + return AuthUser( + result[1], + result[2], + result[3], + result[4], + result[5], + result[6], + result[7], + result[8], + AuthRoleEnum(result[9]), + id=result[0] + ) + + def get_all_auth_users(self) -> List[AuthUser]: + users = List(AuthUser) + self._logger.trace(__name__, f'Send SQL command: {AuthUser.get_select_all_string()}') + results = self._context.select(AuthUser.get_select_all_string()) + for result in results: + self._logger.trace(__name__, f'Get auth user with id {result[0]}') + users.append(self._user_from_result(result)) + + return users + + def get_filtered_auth_users_async(self, criteria: AuthUserSelectCriteria) -> FilteredResult: + users = self.get_all_auth_users() + self._logger.trace(__name__, f'Send SQL command: {AuthUser.get_select_all_string()}') + + query = users + + if criteria.first_name is not None and criteria.first_name != '': + query = query.where(lambda x: criteria.first_name in x.first_name or x.first_name == criteria.first_name) + + if criteria.last_name is not None and criteria.last_name != '': + query = query.where(lambda x: criteria.last_name in x.last_name or x.last_name == criteria.last_name) + + if criteria.email is not None and criteria.email != '': + query = query.where(lambda x: criteria.email in x.email or x.email == criteria.email) + + if criteria.auth_role is not None: + query = query.where(lambda x: x.auth_role == AuthRoleEnum(criteria.auth_role)) + + # sort + if criteria.sort_column is not None and criteria.sort_column != '' and criteria.sort_direction is not None and criteria.sort_direction: + crit_sort_direction = criteria.sort_direction.lower() + if crit_sort_direction == "desc" or crit_sort_direction == "descending": + query = query.order_by_descending(lambda x: getattr(x, criteria.sort_column)) + else: + query = query.order_by(lambda x: getattr(x, criteria.sort_column)) + + skip = criteria.page_size * criteria.page_index + result = FilteredResult() + result.total_count = query.count() + result.result = query.skip(skip).take(criteria.page_size) + + return result + + def get_auth_user_by_email_async(self, email: str) -> AuthUser: + self._logger.trace(__name__, f'Send SQL command: {AuthUser.get_select_by_email_string(email)}') + result = self._context.select(AuthUser.get_select_by_email_string(email))[0] + return self._user_from_result(result) + + def find_auth_user_by_email_async(self, email: str) -> Optional[AuthUser]: + self._logger.trace(__name__, f'Send SQL command: {AuthUser.get_select_by_email_string(email)}') + result = self._context.select(AuthUser.get_select_by_email_string(email)) + if result is None or len(result) == 0: + return None + + result = result[0] + + return self._user_from_result(result) + + def find_auth_user_by_confirmation_id_async(self, id: str) -> Optional[AuthUser]: + self._logger.trace(__name__, f'Send SQL command: {AuthUser.get_select_by_email_string(id)}') + result = self._context.select(AuthUser.get_select_by_email_string(id)) + if result is None or len(result) == 0: + return None + + result = result[0] + + return self._user_from_result(result) + + def find_auth_user_by_forgot_password_id_async(self, id: str) -> Optional[AuthUser]: + self._logger.trace(__name__, f'Send SQL command: {AuthUser.get_select_by_email_string(id)}') + result = self._context.select(AuthUser.get_select_by_email_string(id)) + if result is None or len(result) == 0: + return None + + result = result[0] + + return self._user_from_result(result) + + def add_auth_user(self, user: AuthUser): + self._logger.trace(__name__, f'Send SQL command: {user.insert_string}') + self._context.cursor.execute(user.insert_string) + + def update_auth_user(self, user: AuthUser): + self._logger.trace(__name__, f'Send SQL command: {user.udpate_string}') + self._context.cursor.execute(user.udpate_string) + + def delete_auth_user(self, user: AuthUser): + self._logger.trace(__name__, f'Send SQL command: {user.delete_string}') + self._context.cursor.execute(user.delete_string)