Improved get server logic #72

This commit is contained in:
Sven Heidemann 2022-10-17 16:07:33 +02:00
parent dfcc516389
commit d7a0706e0c
11 changed files with 132 additions and 6 deletions

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
from cpl_query.extension import List from cpl_query.extension import List
@ -23,6 +24,12 @@ class AuthServiceABC(ABC):
@abstractmethod @abstractmethod
def decode_token(self, token: str) -> dict: pass def decode_token(self, token: str) -> dict: pass
@abstractmethod
def get_decoded_token_from_request(self) -> dict: pass
@abstractmethod
def find_decoded_token_from_request(self) -> Optional[dict]: pass
@abstractmethod @abstractmethod
async def get_all_auth_users_async(self) -> List[AuthUserDTO]: pass async def get_all_auth_users_async(self) -> List[AuthUserDTO]: pass

View File

@ -37,4 +37,13 @@ class ServerController:
@Route.get(f'{BasePath}/servers') @Route.get(f'{BasePath}/servers')
@Route.authorize(role=AuthRoleEnum.admin) @Route.authorize(role=AuthRoleEnum.admin)
async def get_all_servers(self) -> Response: async def get_all_servers(self) -> Response:
return jsonify(self._discord_service.get_all_servers().select(lambda x: x.to_dict())) result = await self._discord_service.get_all_servers()
result = result.select(lambda x: x.to_dict())
return jsonify(result)
@Route.get(f'{BasePath}/servers-by-user')
@Route.authorize
async def get_all_servers_by_user(self) -> Response:
result = await self._discord_service.get_all_servers_by_user()
result = result.select(lambda x: x.to_dict())
return jsonify(result)

View File

@ -15,6 +15,7 @@ class AuthUserDTO(DtoABC):
password: str, password: str,
confirmation_id: Optional[str], confirmation_id: Optional[str],
auth_role: AuthRoleEnum, auth_role: AuthRoleEnum,
user_id: Optional[int],
): ):
DtoABC.__init__(self) DtoABC.__init__(self)
@ -25,6 +26,7 @@ class AuthUserDTO(DtoABC):
self._password = password self._password = password
self._is_confirmed = confirmation_id is None self._is_confirmed = confirmation_id is None
self._auth_role = auth_role self._auth_role = auth_role
self._user_id = user_id
@property @property
def id(self) -> int: def id(self) -> int:
@ -78,6 +80,14 @@ class AuthUserDTO(DtoABC):
def auth_role(self, value: AuthRoleEnum): def auth_role(self, value: AuthRoleEnum):
self._auth_role = value self._auth_role = value
@property
def user_id(self) -> Optional[int]:
return self._user_id
@user_id.setter
def user_id(self, value: Optional[int]):
self._user_id = value
def from_dict(self, values: dict): def from_dict(self, values: dict):
self._id = values['id'] self._id = values['id']
self._first_name = values['firstName'] self._first_name = values['firstName']
@ -86,6 +96,7 @@ class AuthUserDTO(DtoABC):
self._password = values['password'] self._password = values['password']
self._is_confirmed = values['isConfirmed'] self._is_confirmed = values['isConfirmed']
self._auth_role = values['authRole'] self._auth_role = values['authRole']
self._user_id = values['userId']
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
@ -96,4 +107,5 @@ class AuthUserDTO(DtoABC):
'password': self._password, 'password': self._password,
'isConfirmed': self._is_confirmed, 'isConfirmed': self._is_confirmed,
'authRole': self._auth_role.value, 'authRole': self._auth_role.value,
'userId': self._user_id,
} }

View File

@ -9,6 +9,7 @@ from cpl_core.database.context import DatabaseContextABC
from cpl_core.mailing import EMailClientABC, EMail from cpl_core.mailing import EMailClientABC, EMail
from cpl_query.extension import List from cpl_query.extension import List
from cpl_translation import TranslatePipe from cpl_translation import TranslatePipe
from flask import request
from bot_api.abc.auth_service_abc import AuthServiceABC from bot_api.abc.auth_service_abc import AuthServiceABC
from bot_api.configuration.authentication_settings import AuthenticationSettings from bot_api.configuration.authentication_settings import AuthenticationSettings
@ -96,6 +97,37 @@ class AuthService(AuthServiceABC):
algorithms=['HS256'] algorithms=['HS256']
) )
def get_decoded_token_from_request(self) -> dict:
token = None
if 'Authorization' in request.headers:
bearer = request.headers.get('Authorization')
token = bearer.split()[1]
if token is None:
raise ServiceException(ServiceErrorCode.Unauthorized, f'Token not set')
return jwt.decode(
token,
key=self._auth_settings.secret_key,
issuer=self._auth_settings.issuer,
audience=self._auth_settings.audience,
algorithms=['HS256']
)
def find_decoded_token_from_request(self) -> Optional[dict]:
token = None
if 'Authorization' in request.headers:
bearer = request.headers.get('Authorization')
token = bearer.split()[1]
return jwt.decode(
token,
key=self._auth_settings.secret_key,
issuer=self._auth_settings.issuer,
audience=self._auth_settings.audience,
algorithms=['HS256']
) if token is not None else None
def _create_and_save_refresh_token(self, user: AuthUser) -> str: def _create_and_save_refresh_token(self, user: AuthUser) -> str:
token = str(uuid.uuid4()) token = str(uuid.uuid4())
user.refresh_token = token user.refresh_token = token

View File

@ -1,9 +1,18 @@
from typing import Optional
from cpl_discord.service import DiscordBotServiceABC from cpl_discord.service import DiscordBotServiceABC
from cpl_query.extension import List from cpl_query.extension import List
from flask import jsonify
from bot_api.abc.auth_service_abc import AuthServiceABC
from bot_api.exception.service_error_code_enum import ServiceErrorCode
from bot_api.exception.service_exception import ServiceException
from bot_api.model.discord.server_dto import ServerDTO from bot_api.model.discord.server_dto import ServerDTO
from bot_api.model.error_dto import ErrorDTO
from bot_api.transformer.server_transformer import ServerTransformer from bot_api.transformer.server_transformer import ServerTransformer
from bot_data.abc.server_repository_abc import ServerRepositoryABC from bot_data.abc.server_repository_abc import ServerRepositoryABC
from bot_data.abc.user_repository_abc import UserRepositoryABC
from bot_data.model.auth_role_enum import AuthRoleEnum
class DiscordService: class DiscordService:
@ -12,12 +21,33 @@ class DiscordService:
self, self,
bot: DiscordBotServiceABC, bot: DiscordBotServiceABC,
servers: ServerRepositoryABC, servers: ServerRepositoryABC,
auth: AuthServiceABC,
users: UserRepositoryABC,
): ):
self._bot = bot self._bot = bot
self._servers = servers self._servers = servers
self._auth = auth
self._users = users
async def get_all_servers(self) -> List[ServerDTO]:
servers = self._servers.get_servers()
return servers.select(
lambda x: ServerTransformer.to_dto(x, self._bot.get_guild(x.discord_server_id).name, self._bot.get_guild(x.discord_server_id).member_count)
)
async def get_all_servers_by_user(self) -> List[ServerDTO]:
token = self._auth.get_decoded_token_from_request()
if token is None or 'email' not in token or 'role' not in token:
raise ServiceException(ServiceErrorCode.InvalidData, 'Token invalid')
role = AuthRoleEnum(token['role'])
if role == AuthRoleEnum.admin:
servers = self._servers.get_servers()
else:
user = await self._auth.find_auth_user_by_email_async(token['email'])
user_from_db = self._users.find_user_by_id(0 if user.user_id is None else user.user_id)
servers = self._servers.get_servers().where(lambda x: user_from_db is not None and x.server_id == user_from_db.server.server_id)
def get_all_servers(self) -> List[ServerDTO]:
servers = self._servers.get_servers().select()
return servers.select( return servers.select(
lambda x: ServerTransformer.to_dto(x, self._bot.get_guild(x.discord_server_id).name, self._bot.get_guild(x.discord_server_id).member_count) lambda x: ServerTransformer.to_dto(x, self._bot.get_guild(x.discord_server_id).name, self._bot.get_guild(x.discord_server_id).member_count)
) )

View File

@ -9,7 +9,7 @@ from bot_data.model.auth_user import AuthUser
class AuthUserTransformer(TransformerABC): class AuthUserTransformer(TransformerABC):
@staticmethod @staticmethod
def to_db(dto: AuthUser) -> AuthUser: def to_db(dto: AuthUserDTO) -> AuthUser:
return AuthUser( return AuthUser(
dto.first_name, dto.first_name,
dto.last_name, dto.last_name,
@ -20,6 +20,7 @@ class AuthUserTransformer(TransformerABC):
None, None,
datetime.now(tz=timezone.utc), datetime.now(tz=timezone.utc),
AuthRoleEnum.normal if dto.auth_role is None else AuthRoleEnum(dto.auth_role), AuthRoleEnum.normal if dto.auth_role is None else AuthRoleEnum(dto.auth_role),
dto.user_id,
id=0 if dto.id is None else dto.id id=0 if dto.id is None else dto.id
) )
@ -32,5 +33,6 @@ class AuthUserTransformer(TransformerABC):
db.email, db.email,
db.password, db.password,
db.confirmation_id, db.confirmation_id,
db.auth_role db.auth_role,
db.user_id
) )

View File

@ -17,6 +17,9 @@ class UserRepositoryABC(ABC):
@abstractmethod @abstractmethod
def get_user_by_id(self, id: int) -> User: pass def get_user_by_id(self, id: int) -> User: pass
@abstractmethod
def find_user_by_id(self, id: int) -> Optional[User]: pass
@abstractmethod @abstractmethod
def get_users_by_discord_id(self, discord_id: int) -> List[User]: pass def get_users_by_discord_id(self, discord_id: int) -> List[User]: pass

View File

@ -28,9 +28,11 @@ class ApiMigration(MigrationABC):
`ForgotPasswordId` VARCHAR(255) DEFAULT NULL, `ForgotPasswordId` VARCHAR(255) DEFAULT NULL,
`RefreshTokenExpiryTime` DATETIME(6) NOT NULL, `RefreshTokenExpiryTime` DATETIME(6) NOT NULL,
`AuthRole` INT NOT NULL DEFAULT '0', `AuthRole` INT NOT NULL DEFAULT '0',
`UserId` BIGINT NOT NULL DEFAULT '0',
`CreatedOn` DATETIME(6) NOT NULL, `CreatedOn` DATETIME(6) NOT NULL,
`LastModifiedOn` DATETIME(6) NOT NULL, `LastModifiedOn` DATETIME(6) NOT NULL,
PRIMARY KEY(`Id`) PRIMARY KEY(`Id`),
FOREIGN KEY (`UserId`) REFERENCES `Users`(`UserId`)
) )
""") """)
) )

View File

@ -19,6 +19,7 @@ class AuthUser(TableABC):
forgot_password_id: Optional[str], forgot_password_id: Optional[str],
refresh_token_expire_time: datetime, refresh_token_expire_time: datetime,
auth_role: AuthRoleEnum, auth_role: AuthRoleEnum,
user_id: Optional[int],
created_at: datetime = None, created_at: datetime = None,
modified_at: datetime = None, modified_at: datetime = None,
id=0 id=0
@ -34,6 +35,7 @@ class AuthUser(TableABC):
self._refresh_token_expire_time = refresh_token_expire_time self._refresh_token_expire_time = refresh_token_expire_time
self._auth_role_id = auth_role self._auth_role_id = auth_role
self._user_id = user_id
TableABC.__init__(self) TableABC.__init__(self)
self._created_at = created_at if created_at is not None else self._created_at self._created_at = created_at if created_at is not None else self._created_at
@ -115,6 +117,14 @@ class AuthUser(TableABC):
def auth_role(self, value: AuthRoleEnum): def auth_role(self, value: AuthRoleEnum):
self._auth_role_id = value self._auth_role_id = value
@property
def user_id(self) -> Optional[int]:
return self._user_id
@user_id.setter
def user_id(self, value: Optional[int]):
self._user_id = value
@staticmethod @staticmethod
def get_select_all_string() -> str: def get_select_all_string() -> str:
return str(f""" return str(f"""
@ -163,6 +173,7 @@ class AuthUser(TableABC):
`ForgotPasswordId`, `ForgotPasswordId`,
`RefreshTokenExpiryTime`, `RefreshTokenExpiryTime`,
`AuthRole`, `AuthRole`,
`UserId`,
`CreatedOn`, `CreatedOn`,
`LastModifiedOn` `LastModifiedOn`
) VALUES ( ) VALUES (
@ -176,6 +187,7 @@ class AuthUser(TableABC):
'{"NULL" if self._forgot_password_id is None else self._forgot_password_id}', '{"NULL" if self._forgot_password_id is None else self._forgot_password_id}',
'{self._refresh_token_expire_time}', '{self._refresh_token_expire_time}',
{self._auth_role_id.value}, {self._auth_role_id.value},
{"NULL" if self._user_id is None else self._user_id}
'{self._created_at}', '{self._created_at}',
'{self._modified_at}' '{self._modified_at}'
) )
@ -194,6 +206,7 @@ class AuthUser(TableABC):
`ForgotPasswordId` = '{"NULL" if self._forgot_password_id is None else self._forgot_password_id}', `ForgotPasswordId` = '{"NULL" if self._forgot_password_id is None else self._forgot_password_id}',
`RefreshTokenExpiryTime` = '{self._refresh_token_expire_time}', `RefreshTokenExpiryTime` = '{self._refresh_token_expire_time}',
`AuthRole` = {self._auth_role_id.value}, `AuthRole` = {self._auth_role_id.value},
`UserId` = {"NULL" if self._user_id is None else self._user_id},
`LastModifiedOn` = '{self._modified_at}' `LastModifiedOn` = '{self._modified_at}'
WHERE `AuthUsers`.`Id` = {self._auth_user_id}; WHERE `AuthUsers`.`Id` = {self._auth_user_id};
""") """)

View File

@ -37,6 +37,7 @@ class AuthUserRepositoryService(AuthUserRepositoryABC):
self._get_value_from_result(result[7]), self._get_value_from_result(result[7]),
self._get_value_from_result(result[8]), self._get_value_from_result(result[8]),
AuthRoleEnum(self._get_value_from_result(result[9])), AuthRoleEnum(self._get_value_from_result(result[9])),
self._get_value_from_result(result[10]),
id=self._get_value_from_result(result[0]) id=self._get_value_from_result(result[0])
) )

View File

@ -45,6 +45,21 @@ class UserRepositoryService(UserRepositoryABC):
id=result[0] id=result[0]
) )
def find_user_by_id(self, id: int) -> Optional[User]:
self._logger.trace(__name__, f'Send SQL command: {User.get_select_by_id_string(id)}')
result = self._context.select(User.get_select_by_id_string(id))
if result is None or len(result) == 0:
return None
result = result[0]
return User(
result[1],
result[2],
self._servers.get_server_by_id(result[3]),
id=result[0]
)
def get_users_by_discord_id(self, discord_id: int) -> List[User]: def get_users_by_discord_id(self, discord_id: int) -> List[User]:
users = List(User) users = List(User)
self._logger.trace(__name__, f'Send SQL command: {User.get_select_by_discord_id_string(discord_id)}') self._logger.trace(__name__, f'Send SQL command: {User.get_select_by_discord_id_string(discord_id)}')