forked from sh-edraft.de/sh_discord_bot
		
	Improved get server logic #72
This commit is contained in:
		| @@ -1,4 +1,5 @@ | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Optional | ||||
|  | ||||
| from cpl_query.extension import List | ||||
|  | ||||
| @@ -23,6 +24,12 @@ class AuthServiceABC(ABC): | ||||
|     @abstractmethod | ||||
|     def decode_token(self, token: str) -> dict: pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def get_decoded_token_from_request(self) -> dict: pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def find_decoded_token_from_request(self) -> Optional[dict]: pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     async def get_all_auth_users_async(self) -> List[AuthUserDTO]: pass | ||||
|  | ||||
|   | ||||
| @@ -37,4 +37,13 @@ class ServerController: | ||||
|     @Route.get(f'{BasePath}/servers') | ||||
|     @Route.authorize(role=AuthRoleEnum.admin) | ||||
|     async def get_all_servers(self) -> Response: | ||||
|         return jsonify(self._discord_service.get_all_servers().select(lambda x: x.to_dict())) | ||||
|         result = await self._discord_service.get_all_servers() | ||||
|         result = result.select(lambda x: x.to_dict()) | ||||
|         return jsonify(result) | ||||
|  | ||||
|     @Route.get(f'{BasePath}/servers-by-user') | ||||
|     @Route.authorize | ||||
|     async def get_all_servers_by_user(self) -> Response: | ||||
|         result = await self._discord_service.get_all_servers_by_user() | ||||
|         result = result.select(lambda x: x.to_dict()) | ||||
|         return jsonify(result) | ||||
|   | ||||
| @@ -15,6 +15,7 @@ class AuthUserDTO(DtoABC): | ||||
|             password: str, | ||||
|             confirmation_id: Optional[str], | ||||
|             auth_role: AuthRoleEnum, | ||||
|             user_id: Optional[int], | ||||
|     ): | ||||
|         DtoABC.__init__(self) | ||||
|  | ||||
| @@ -25,6 +26,7 @@ class AuthUserDTO(DtoABC): | ||||
|         self._password = password | ||||
|         self._is_confirmed = confirmation_id is None | ||||
|         self._auth_role = auth_role | ||||
|         self._user_id = user_id | ||||
|          | ||||
|     @property | ||||
|     def id(self) -> int: | ||||
| @@ -78,6 +80,14 @@ class AuthUserDTO(DtoABC): | ||||
|     def auth_role(self, value: AuthRoleEnum): | ||||
|         self._auth_role = value | ||||
|  | ||||
|     @property | ||||
|     def user_id(self) -> Optional[int]: | ||||
|         return self._user_id | ||||
|  | ||||
|     @user_id.setter | ||||
|     def user_id(self, value: Optional[int]): | ||||
|         self._user_id = value | ||||
|  | ||||
|     def from_dict(self, values: dict): | ||||
|         self._id = values['id'] | ||||
|         self._first_name = values['firstName'] | ||||
| @@ -86,6 +96,7 @@ class AuthUserDTO(DtoABC): | ||||
|         self._password = values['password'] | ||||
|         self._is_confirmed = values['isConfirmed'] | ||||
|         self._auth_role = values['authRole'] | ||||
|         self._user_id = values['userId'] | ||||
|  | ||||
|     def to_dict(self) -> dict: | ||||
|         return { | ||||
| @@ -96,4 +107,5 @@ class AuthUserDTO(DtoABC): | ||||
|             'password': self._password, | ||||
|             'isConfirmed': self._is_confirmed, | ||||
|             'authRole': self._auth_role.value, | ||||
|             'userId': self._user_id, | ||||
|         } | ||||
|   | ||||
| @@ -9,6 +9,7 @@ from cpl_core.database.context import DatabaseContextABC | ||||
| from cpl_core.mailing import EMailClientABC, EMail | ||||
| from cpl_query.extension import List | ||||
| from cpl_translation import TranslatePipe | ||||
| from flask import request | ||||
|  | ||||
| from bot_api.abc.auth_service_abc import AuthServiceABC | ||||
| from bot_api.configuration.authentication_settings import AuthenticationSettings | ||||
| @@ -96,6 +97,37 @@ class AuthService(AuthServiceABC): | ||||
|             algorithms=['HS256'] | ||||
|         ) | ||||
|  | ||||
|     def get_decoded_token_from_request(self) -> dict: | ||||
|         token = None | ||||
|         if 'Authorization' in request.headers: | ||||
|             bearer = request.headers.get('Authorization') | ||||
|             token = bearer.split()[1] | ||||
|  | ||||
|         if token is None: | ||||
|             raise ServiceException(ServiceErrorCode.Unauthorized, f'Token not set') | ||||
|  | ||||
|         return jwt.decode( | ||||
|             token, | ||||
|             key=self._auth_settings.secret_key, | ||||
|             issuer=self._auth_settings.issuer, | ||||
|             audience=self._auth_settings.audience, | ||||
|             algorithms=['HS256'] | ||||
|         ) | ||||
|  | ||||
|     def find_decoded_token_from_request(self) -> Optional[dict]: | ||||
|         token = None | ||||
|         if 'Authorization' in request.headers: | ||||
|             bearer = request.headers.get('Authorization') | ||||
|             token = bearer.split()[1] | ||||
|  | ||||
|         return jwt.decode( | ||||
|             token, | ||||
|             key=self._auth_settings.secret_key, | ||||
|             issuer=self._auth_settings.issuer, | ||||
|             audience=self._auth_settings.audience, | ||||
|             algorithms=['HS256'] | ||||
|         ) if token is not None else None | ||||
|  | ||||
|     def _create_and_save_refresh_token(self, user: AuthUser) -> str: | ||||
|         token = str(uuid.uuid4()) | ||||
|         user.refresh_token = token | ||||
|   | ||||
| @@ -1,9 +1,18 @@ | ||||
| from typing import Optional | ||||
|  | ||||
| from cpl_discord.service import DiscordBotServiceABC | ||||
| from cpl_query.extension import List | ||||
| from flask import jsonify | ||||
|  | ||||
| from bot_api.abc.auth_service_abc import AuthServiceABC | ||||
| from bot_api.exception.service_error_code_enum import ServiceErrorCode | ||||
| from bot_api.exception.service_exception import ServiceException | ||||
| from bot_api.model.discord.server_dto import ServerDTO | ||||
| from bot_api.model.error_dto import ErrorDTO | ||||
| from bot_api.transformer.server_transformer import ServerTransformer | ||||
| from bot_data.abc.server_repository_abc import ServerRepositoryABC | ||||
| from bot_data.abc.user_repository_abc import UserRepositoryABC | ||||
| from bot_data.model.auth_role_enum import AuthRoleEnum | ||||
|  | ||||
|  | ||||
| class DiscordService: | ||||
| @@ -12,12 +21,33 @@ class DiscordService: | ||||
|             self, | ||||
|             bot: DiscordBotServiceABC, | ||||
|             servers: ServerRepositoryABC, | ||||
|             auth: AuthServiceABC, | ||||
|             users: UserRepositoryABC, | ||||
|     ): | ||||
|         self._bot = bot | ||||
|         self._servers = servers | ||||
|         self._auth = auth | ||||
|         self._users = users | ||||
|  | ||||
|     async def get_all_servers(self) -> List[ServerDTO]: | ||||
|         servers = self._servers.get_servers() | ||||
|         return servers.select( | ||||
|             lambda x: ServerTransformer.to_dto(x, self._bot.get_guild(x.discord_server_id).name, self._bot.get_guild(x.discord_server_id).member_count) | ||||
|         ) | ||||
|  | ||||
|     async def get_all_servers_by_user(self) -> List[ServerDTO]: | ||||
|         token = self._auth.get_decoded_token_from_request() | ||||
|         if token is None or 'email' not in token or 'role' not in token: | ||||
|             raise ServiceException(ServiceErrorCode.InvalidData, 'Token invalid') | ||||
|  | ||||
|         role = AuthRoleEnum(token['role']) | ||||
|         if role == AuthRoleEnum.admin: | ||||
|             servers = self._servers.get_servers() | ||||
|         else: | ||||
|             user = await self._auth.find_auth_user_by_email_async(token['email']) | ||||
|             user_from_db = self._users.find_user_by_id(0 if user.user_id is None else user.user_id) | ||||
|             servers = self._servers.get_servers().where(lambda x: user_from_db is not None and x.server_id == user_from_db.server.server_id) | ||||
|  | ||||
|     def get_all_servers(self) -> List[ServerDTO]: | ||||
|         servers = self._servers.get_servers().select() | ||||
|         return servers.select( | ||||
|             lambda x: ServerTransformer.to_dto(x, self._bot.get_guild(x.discord_server_id).name, self._bot.get_guild(x.discord_server_id).member_count) | ||||
|         ) | ||||
|   | ||||
| @@ -9,7 +9,7 @@ from bot_data.model.auth_user import AuthUser | ||||
| class AuthUserTransformer(TransformerABC): | ||||
|  | ||||
|     @staticmethod | ||||
|     def to_db(dto: AuthUser) -> AuthUser: | ||||
|     def to_db(dto: AuthUserDTO) -> AuthUser: | ||||
|         return AuthUser( | ||||
|             dto.first_name, | ||||
|             dto.last_name, | ||||
| @@ -20,6 +20,7 @@ class AuthUserTransformer(TransformerABC): | ||||
|             None, | ||||
|             datetime.now(tz=timezone.utc), | ||||
|             AuthRoleEnum.normal if dto.auth_role is None else AuthRoleEnum(dto.auth_role), | ||||
|             dto.user_id, | ||||
|             id=0 if dto.id is None else dto.id | ||||
|         ) | ||||
|  | ||||
| @@ -32,5 +33,6 @@ class AuthUserTransformer(TransformerABC): | ||||
|             db.email, | ||||
|             db.password, | ||||
|             db.confirmation_id, | ||||
|             db.auth_role | ||||
|             db.auth_role, | ||||
|             db.user_id | ||||
|         ) | ||||
|   | ||||
| @@ -16,6 +16,9 @@ class UserRepositoryABC(ABC): | ||||
|      | ||||
|     @abstractmethod | ||||
|     def get_user_by_id(self, id: int) -> User: pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def find_user_by_id(self, id: int) -> Optional[User]: pass | ||||
|      | ||||
|     @abstractmethod | ||||
|     def get_users_by_discord_id(self, discord_id: int) -> List[User]: pass | ||||
|   | ||||
| @@ -28,9 +28,11 @@ class ApiMigration(MigrationABC): | ||||
|               `ForgotPasswordId` VARCHAR(255) DEFAULT NULL, | ||||
|               `RefreshTokenExpiryTime` DATETIME(6) NOT NULL, | ||||
|               `AuthRole` INT NOT NULL DEFAULT '0', | ||||
|               `UserId` BIGINT NOT NULL DEFAULT '0', | ||||
|               `CreatedOn` DATETIME(6) NOT NULL, | ||||
|               `LastModifiedOn` DATETIME(6) NOT NULL, | ||||
|               PRIMARY KEY(`Id`) | ||||
|               PRIMARY KEY(`Id`), | ||||
|               FOREIGN KEY (`UserId`) REFERENCES `Users`(`UserId`) | ||||
|             ) | ||||
|             """) | ||||
|         ) | ||||
|   | ||||
| @@ -19,6 +19,7 @@ class AuthUser(TableABC): | ||||
|             forgot_password_id: Optional[str], | ||||
|             refresh_token_expire_time: datetime, | ||||
|             auth_role: AuthRoleEnum, | ||||
|             user_id: Optional[int], | ||||
|             created_at: datetime = None, | ||||
|             modified_at: datetime = None, | ||||
|             id=0 | ||||
| @@ -34,6 +35,7 @@ class AuthUser(TableABC): | ||||
|         self._refresh_token_expire_time = refresh_token_expire_time | ||||
|  | ||||
|         self._auth_role_id = auth_role | ||||
|         self._user_id = user_id | ||||
|  | ||||
|         TableABC.__init__(self) | ||||
|         self._created_at = created_at if created_at is not None else self._created_at | ||||
| @@ -115,6 +117,14 @@ class AuthUser(TableABC): | ||||
|     def auth_role(self, value: AuthRoleEnum): | ||||
|         self._auth_role_id = value | ||||
|  | ||||
|     @property | ||||
|     def user_id(self) -> Optional[int]: | ||||
|         return self._user_id | ||||
|  | ||||
|     @user_id.setter | ||||
|     def user_id(self, value: Optional[int]): | ||||
|         self._user_id = value | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_select_all_string() -> str: | ||||
|         return str(f""" | ||||
| @@ -163,6 +173,7 @@ class AuthUser(TableABC): | ||||
|                 `ForgotPasswordId`, | ||||
|                 `RefreshTokenExpiryTime`, | ||||
|                 `AuthRole`, | ||||
|                 `UserId`, | ||||
|                 `CreatedOn`, | ||||
|                 `LastModifiedOn` | ||||
|                 ) VALUES ( | ||||
| @@ -176,6 +187,7 @@ class AuthUser(TableABC): | ||||
|                     '{"NULL" if self._forgot_password_id is None else self._forgot_password_id}', | ||||
|                     '{self._refresh_token_expire_time}', | ||||
|                     {self._auth_role_id.value}, | ||||
|                     {"NULL" if self._user_id is None else self._user_id} | ||||
|                     '{self._created_at}',  | ||||
|                     '{self._modified_at}' | ||||
|                 ) | ||||
| @@ -194,6 +206,7 @@ class AuthUser(TableABC): | ||||
|             `ForgotPasswordId` = '{"NULL" if self._forgot_password_id is None else self._forgot_password_id}', | ||||
|             `RefreshTokenExpiryTime` = '{self._refresh_token_expire_time}', | ||||
|             `AuthRole` = {self._auth_role_id.value}, | ||||
|             `UserId` = {"NULL" if self._user_id is None else self._user_id}, | ||||
|             `LastModifiedOn` = '{self._modified_at}' | ||||
|             WHERE `AuthUsers`.`Id` = {self._auth_user_id}; | ||||
|         """) | ||||
|   | ||||
| @@ -37,6 +37,7 @@ class AuthUserRepositoryService(AuthUserRepositoryABC): | ||||
|             self._get_value_from_result(result[7]), | ||||
|             self._get_value_from_result(result[8]), | ||||
|             AuthRoleEnum(self._get_value_from_result(result[9])), | ||||
|             self._get_value_from_result(result[10]), | ||||
|             id=self._get_value_from_result(result[0]) | ||||
|         ) | ||||
|  | ||||
|   | ||||
| @@ -44,6 +44,21 @@ class UserRepositoryService(UserRepositoryABC): | ||||
|             self._servers.get_server_by_id(result[3]), | ||||
|             id=result[0] | ||||
|         ) | ||||
|  | ||||
|     def find_user_by_id(self, id: int) -> Optional[User]: | ||||
|         self._logger.trace(__name__, f'Send SQL command: {User.get_select_by_id_string(id)}') | ||||
|         result = self._context.select(User.get_select_by_id_string(id)) | ||||
|         if result is None or len(result) == 0: | ||||
|             return None | ||||
|  | ||||
|         result = result[0] | ||||
|  | ||||
|         return User( | ||||
|             result[1], | ||||
|             result[2], | ||||
|             self._servers.get_server_by_id(result[3]), | ||||
|             id=result[0] | ||||
|         ) | ||||
|      | ||||
|     def get_users_by_discord_id(self, discord_id: int) -> List[User]: | ||||
|         users = List(User) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user