diff --git a/src/gismo_data/abc/server_repository_abc.py b/src/gismo_data/abc/server_repository_abc.py index a9106f4..2c082f1 100644 --- a/src/gismo_data/abc/server_repository_abc.py +++ b/src/gismo_data/abc/server_repository_abc.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional from cpl_query.extension import List @@ -19,6 +20,9 @@ class ServerRepositoryABC(ABC): @abstractmethod def get_server_by_discord_id(self, discord_id: int) -> Server: pass + @abstractmethod + def find_server_by_discord_id(self, discord_id: int) -> Optional[Server]: pass + @abstractmethod def add_server(self, server: Server) -> int: pass diff --git a/src/gismo_data/abc/user_repository_abc.py b/src/gismo_data/abc/user_repository_abc.py index 7952f42..aa40998 100644 --- a/src/gismo_data/abc/user_repository_abc.py +++ b/src/gismo_data/abc/user_repository_abc.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional from cpl_query.extension import List @@ -19,6 +20,9 @@ class UserRepositoryABC(ABC): @abstractmethod def get_user_by_discord_id(self, discord_id: int) -> User: pass + @abstractmethod + def find_user_by_discord_id(self, discord_id: int) -> Optional[User]: pass + @abstractmethod def add_user(self, user: User) -> int: pass diff --git a/src/gismo_data/service/server_repository_service.py b/src/gismo_data/service/server_repository_service.py index 9cf26d8..09ed488 100644 --- a/src/gismo_data/service/server_repository_service.py +++ b/src/gismo_data/service/server_repository_service.py @@ -1,12 +1,16 @@ +from typing import Optional from cpl_core.database.context import DatabaseContextABC +from cpl_core.logging import LoggerABC from cpl_query.extension import List + from gismo_data.abc.server_repository_abc import ServerRepositoryABC from gismo_data.model.server import Server class ServerRepositoryService(ServerRepositoryABC): - def __init__(self, db_context: DatabaseContextABC): + def __init__(self, logger: LoggerABC, db_context: DatabaseContextABC): + self._logger = logger self._context = db_context ServerRepositoryABC.__init__(self) @@ -38,6 +42,17 @@ class ServerRepositoryService(ServerRepositoryABC): result[1], id=result[0] ) + + def find_server_by_discord_id(self, discord_id: int) -> Optional[Server]: + self._logger.trace(__name__, f'Send SQL command: {Server.get_select_by_discord_id_string(discord_id)}') + result = self._context.select(Server.get_select_by_discord_id_string(discord_id)) + if len(result) == 0: + return None + + return Server( + result[1], + id=result[0] + ) def add_server(self, server: Server) -> int: self._logger.trace(__name__, f'Send SQL command: {server.insert_string}') diff --git a/src/gismo_data/service/user_repository_service.py b/src/gismo_data/service/user_repository_service.py index 1e94084..6212c82 100644 --- a/src/gismo_data/service/user_repository_service.py +++ b/src/gismo_data/service/user_repository_service.py @@ -1,3 +1,4 @@ +from typing import Optional from cpl_core.database.context import DatabaseContextABC from cpl_core.logging import LoggerABC from cpl_query.extension import List @@ -50,6 +51,19 @@ class UserRepositoryService(UserRepositoryABC): self._servers.get_server_by_id(result[3]), id=result[0] ) + + def find_user_by_discord_id(self, discord_id: int) -> Optional[User]: + self._logger.trace(__name__, f'Send SQL command: {User.get_select_by_discord_id_string(discord_id)}') + result = self._context.select(User.get_select_by_discord_id_string(discord_id)) + if len(result) == 0: + return None + + return User( + result[1], + result[2], + self._servers.get_server_by_id(result[3]), + id=result[0] + ) def add_user(self, user: User) -> int: self._logger.trace(__name__, f'Send SQL command: {user.insert_strin}')