diff --git a/kdb-bot/src/bot_data/abc/level_repository_abc.py b/kdb-bot/src/bot_data/abc/level_repository_abc.py new file mode 100644 index 0000000000..89a2a67f55 --- /dev/null +++ b/kdb-bot/src/bot_data/abc/level_repository_abc.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from cpl_query.extension import List + +from bot_data.model.level import Level + + +class LevelRepositoryABC(ABC): + + @abstractmethod + def __init__(self): pass + + @abstractmethod + def get_levels(self) -> List[Level]: pass + + @abstractmethod + def get_level_by_id(self, id: int) -> Level: pass + + @abstractmethod + def find_level_by_id(self, id: int) -> Optional[Level]: pass + + @abstractmethod + def get_levels_by_server_id(self, server_id: int) -> List[Level]: pass + + @abstractmethod + def find_levels_by_server_id(self, server_id: int) -> Optional[List[Level]]: pass + + @abstractmethod + def add_level(self, level: Level): pass + + @abstractmethod + def update_level(self, level: Level): pass + + @abstractmethod + def delete_level(self, level: Level): pass diff --git a/kdb-bot/src/bot_data/data_module.py b/kdb-bot/src/bot_data/data_module.py index 7537b6f282..bf8786e529 100644 --- a/kdb-bot/src/bot_data/data_module.py +++ b/kdb-bot/src/bot_data/data_module.py @@ -9,6 +9,7 @@ from bot_data.abc.auth_user_repository_abc import AuthUserRepositoryABC from bot_data.abc.auto_role_repository_abc import AutoRoleRepositoryABC from bot_data.abc.client_repository_abc import ClientRepositoryABC from bot_data.abc.known_user_repository_abc import KnownUserRepositoryABC +from bot_data.abc.level_repository_abc import LevelRepositoryABC from bot_data.abc.server_repository_abc import ServerRepositoryABC from bot_data.abc.user_joined_server_repository_abc import UserJoinedServerRepositoryABC from bot_data.abc.user_joined_voice_channel_abc import UserJoinedVoiceChannelRepositoryABC @@ -17,6 +18,7 @@ from bot_data.service.auth_user_repository_service import AuthUserRepositoryServ from bot_data.service.auto_role_repository_service import AutoRoleRepositoryService from bot_data.service.client_repository_service import ClientRepositoryService from bot_data.service.known_user_repository_service import KnownUserRepositoryService +from bot_data.service.level_repository_service import LevelRepositoryService from bot_data.service.server_repository_service import ServerRepositoryService from bot_data.service.user_joined_server_repository_service import UserJoinedServerRepositoryService from bot_data.service.user_joined_voice_channel_service import UserJoinedVoiceChannelRepositoryService @@ -40,3 +42,4 @@ class DataModule(ModuleABC): services.add_transient(UserJoinedServerRepositoryABC, UserJoinedServerRepositoryService) services.add_transient(UserJoinedVoiceChannelRepositoryABC, UserJoinedVoiceChannelRepositoryService) services.add_transient(AutoRoleRepositoryABC, AutoRoleRepositoryService) + services.add_transient(LevelRepositoryABC, LevelRepositoryService) diff --git a/kdb-bot/src/bot_data/model/level.py b/kdb-bot/src/bot_data/model/level.py index abd7e68737..fcde27bfcf 100644 --- a/kdb-bot/src/bot_data/model/level.py +++ b/kdb-bot/src/bot_data/model/level.py @@ -73,7 +73,7 @@ class Level(TableABC): """) @staticmethod - def get_select_by_server_id_string(dc_id: int, s_id: int) -> str: + def get_select_by_server_id_string(s_id: int) -> str: return str(f""" SELECT * FROM `Levels` WHERE `ServerId` = {s_id}; diff --git a/kdb-bot/src/bot_data/service/level_repository_service.py b/kdb-bot/src/bot_data/service/level_repository_service.py new file mode 100644 index 0000000000..37ee595298 --- /dev/null +++ b/kdb-bot/src/bot_data/service/level_repository_service.py @@ -0,0 +1,97 @@ +from typing import Optional + +from cpl_core.database.context import DatabaseContextABC +from cpl_query.extension import List + +from bot_core.logging.database_logger import DatabaseLogger +from bot_data.abc.server_repository_abc import ServerRepositoryABC +from bot_data.abc.level_repository_abc import LevelRepositoryABC +from bot_data.model.level import Level + + +class LevelRepositoryService(LevelRepositoryABC): + + def __init__(self, logger: DatabaseLogger, db_context: DatabaseContextABC, servers: ServerRepositoryABC): + self._logger = logger + self._context = db_context + + self._servers = servers + + LevelRepositoryABC.__init__(self) + + @staticmethod + def _get_value_from_result(value: any) -> Optional[any]: + if isinstance(value, str) and 'NULL' in value: + return None + + return value + + def _level_from_result(self, sql_result: tuple) -> Level: + return Level( + self._get_value_from_result(sql_result[1]), # name + self._get_value_from_result(sql_result[2]), # color + int(self._get_value_from_result(sql_result[3])), # min xp + int(self._get_value_from_result(sql_result[4])), # permissions + self._servers.get_server_by_id(sql_result[5]), # server + id=self._get_value_from_result(sql_result[0]) # id + ) + + def get_levels(self) -> List[Level]: + levels = List(Level) + self._logger.trace(__name__, f'Send SQL command: {Level.get_select_all_string()}') + results = self._context.select(Level.get_select_all_string()) + for result in results: + self._logger.trace(__name__, f'Get level with id {result[0]}') + levels.append(self._level_from_result(result)) + + return levels + + def get_level_by_id(self, id: int) -> Level: + self._logger.trace(__name__, f'Send SQL command: {Level.get_select_by_id_string(id)}') + result = self._context.select(Level.get_select_by_id_string(id))[0] + + return self._level_from_result(result) + + def find_level_by_id(self, id: int) -> Optional[Level]: + self._logger.trace(__name__, f'Send SQL command: {Level.get_select_by_id_string(id)}') + result = self._context.select(Level.get_select_by_id_string(id)) + if result is None or len(result) == 0: + return None + + return self._level_from_result(result[0]) + + def get_levels_by_server_id(self, server_id: int) -> List[Level]: + levels = List(Level) + self._logger.trace(__name__, f'Send SQL command: {Level.get_select_by_server_id_string(server_id)}') + results = self._context.select(Level.get_select_by_server_id_string(server_id))[0] + + for result in results: + self._logger.trace(__name__, f'Get level with id {result[0]}') + levels.append(self._level_from_result(result)) + + return levels + + def find_levels_by_server_id(self, server_id: int) -> Optional[List[Level]]: + levels = List(Level) + self._logger.trace(__name__, f'Send SQL command: {Level.get_select_by_server_id_string(server_id)}') + results = self._context.select(Level.get_select_by_server_id_string(server_id)) + if results is None or len(results) == 0: + return None + + for result in results: + self._logger.trace(__name__, f'Get level with id {result[0]}') + levels.append(self._level_from_result(result)) + + return levels + + def add_level(self, level: Level): + self._logger.trace(__name__, f'Send SQL command: {level.insert_string}') + self._context.cursor.execute(level.insert_string) + + def update_level(self, level: Level): + self._logger.trace(__name__, f'Send SQL command: {level.udpate_string}') + self._context.cursor.execute(level.udpate_string) + + def delete_level(self, level: Level): + self._logger.trace(__name__, f'Send SQL command: {level.delete_string}') + self._context.cursor.execute(level.delete_string)