diff --git a/src/bot_data/abc/auto_role_repository_abc.py b/src/bot_data/abc/auto_role_repository_abc.py index 6afa0730..7b9c0ed1 100644 --- a/src/bot_data/abc/auto_role_repository_abc.py +++ b/src/bot_data/abc/auto_role_repository_abc.py @@ -47,10 +47,7 @@ class AutoRoleRepositoryABC(ABC): def get_auto_role_rule_by_id(self, id: int) -> AutoRoleRule: pass @abstractmethod - def get_auto_role_rules_by_auto_role_id(self, id: int) -> AutoRoleRule: pass - - @abstractmethod - def find_auto_role_rules_by_auto_role_id(self, id: int) -> Optional[AutoRoleRule]: pass + def get_auto_role_rules_by_auto_role_id(self, id: int) -> List[AutoRoleRule]: pass @abstractmethod def add_auto_role_rule(self, server: AutoRoleRule): pass diff --git a/src/bot_data/abc/migration_abc.py b/src/bot_data/abc/migration_abc.py index 53b1696e..a3ce7ddc 100644 --- a/src/bot_data/abc/migration_abc.py +++ b/src/bot_data/abc/migration_abc.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod class MigrationABC(ABC): + name = None @abstractmethod def __init__(self): pass diff --git a/src/bot_data/migration/auto_role_migration.py b/src/bot_data/migration/auto_role_migration.py index 9aad5f72..2294ff95 100644 --- a/src/bot_data/migration/auto_role_migration.py +++ b/src/bot_data/migration/auto_role_migration.py @@ -4,8 +4,10 @@ from bot_data.db_context import DBContext class AutoRoleMigration(MigrationABC): + name = '0.2.1_AutoRoleMigration' def __init__(self, logger: DatabaseLogger, db: DBContext): + MigrationABC.__init__(self) self._logger = logger self._db = db self._cursor = db.cursor @@ -31,7 +33,7 @@ class AutoRoleMigration(MigrationABC): CREATE TABLE IF NOT EXISTS `AutoRoleRules` ( `AutoRoleRuleId` BIGINT NOT NULL AUTO_INCREMENT, `AutoRoleId` BIGINT, - `DiscordEmojiId` BIGINT NOT NULL, + `DiscordEmojiName` VARCHAR(64), `DiscordRoleId` BIGINT NOT NULL, `CreatedAt` DATETIME(6), `LastModifiedAt` DATETIME(6), diff --git a/src/bot_data/migration/initial_migration.py b/src/bot_data/migration/initial_migration.py index 76dd481d..54c2ae8f 100644 --- a/src/bot_data/migration/initial_migration.py +++ b/src/bot_data/migration/initial_migration.py @@ -4,8 +4,10 @@ from bot_data.db_context import DBContext class InitialMigration(MigrationABC): + name = '0.1_InitialMigration' def __init__(self, logger: DatabaseLogger, db: DBContext): + MigrationABC.__init__(self) self._logger = logger self._db = db self._cursor = db.cursor diff --git a/src/bot_data/model/auto_role.py b/src/bot_data/model/auto_role.py index b81c1c76..26d5dba2 100644 --- a/src/bot_data/model/auto_role.py +++ b/src/bot_data/model/auto_role.py @@ -14,7 +14,19 @@ class AutoRole(TableABC): TableABC.__init__(self) self._created_at = created_at if created_at is not None else self._created_at self._modified_at = modified_at if modified_at is not None else self._modified_at - + + @property + def auto_role_id(self) -> int: + return self._auto_role_id + + @property + def server_id(self) -> int: + return self._server_id + + @property + def discord_message_id(self) -> int: + return self._discord_message_id + @staticmethod def get_select_all_string() -> str: return str(f""" diff --git a/src/bot_data/model/auto_role_rule.py b/src/bot_data/model/auto_role_rule.py index 15ec3e32..ca9fce32 100644 --- a/src/bot_data/model/auto_role_rule.py +++ b/src/bot_data/model/auto_role_rule.py @@ -6,16 +6,24 @@ from cpl_core.database import TableABC class AutoRoleRule(TableABC): - def __init__(self, auto_role_id: int, discord_emoji_id: int, discord_role_id: int, created_at: datetime=None, modified_at: datetime=None, id=0): + def __init__(self, auto_role_id: int, discord_emoji_name: str, discord_role_id: int, created_at: datetime=None, modified_at: datetime=None, id=0): self._auto_role_rule_id = id self._auto_role_id = auto_role_id - self._discord_emoji_id = discord_emoji_id + self._discord_emoji_name = discord_emoji_name self._discord_role_id = discord_role_id TableABC.__init__(self) self._created_at = created_at if created_at is not None else self._created_at self._modified_at = modified_at if modified_at is not None else self._modified_at + @property + def emoji_name(self): + return self._discord_emoji_name + + @property + def role_id(self): + return self._discord_role_id + @staticmethod def get_select_all_string() -> str: return str(f""" @@ -40,10 +48,10 @@ class AutoRoleRule(TableABC): def insert_string(self) -> str: return str(f""" INSERT INTO `AutoRoleRules` ( - `AutoRoleId`, `DiscordEmojiId`, `DiscordRoleId`, `CreatedAt`, `LastModifiedAt` + `AutoRoleId`, `DiscordEmojiName`, `DiscordRoleId`, `CreatedAt`, `LastModifiedAt` ) VALUES ( {self._auto_role_id}, - {self._discord_emoji_id}, + {self._discord_emoji_name}, {self._discord_role_id}, '{self._created_at}', '{self._modified_at}' @@ -55,7 +63,7 @@ class AutoRoleRule(TableABC): return str(f""" UPDATE `AutoRoleRules` SET `AutoRoleId` = {self._auto_role_id}, - SET `DiscordEmojiId` = {self._discord_emoji_id}, + SET `DiscordEmojiName` = {self._discord_emoji_name}, SET `DiscordRoleId` = {self._discord_role_id}, `LastModifiedAt` = '{self._modified_at}' WHERE `AutoRoleRuleId` = {self._auto_role_id}; diff --git a/src/bot_data/service/auto_role_repository_service.py b/src/bot_data/service/auto_role_repository_service.py index 654d7635..7b02bc72 100644 --- a/src/bot_data/service/auto_role_repository_service.py +++ b/src/bot_data/service/auto_role_repository_service.py @@ -137,34 +137,21 @@ class AutoRoleRepositoryService(AutoRoleRepositoryABC): id=result[0] ) - def get_auto_role_rules_by_auto_role_id(self, id: int) -> AutoRoleRule: + def get_auto_role_rules_by_auto_role_id(self, id: int) -> List[AutoRoleRule]: + auto_role_rules = List(AutoRoleRule) self._logger.trace(__name__, f'Send SQL command: {AutoRoleRule.get_select_by_auto_role_id_string(id)}') - result = self._context.select(AutoRoleRule.get_select_by_auto_role_id_string(id))[0] - return AutoRoleRule( - result[1], - result[2], - result[3], - result[4], - result[5], - id=result[0] - ) + results = self._context.select(AutoRoleRule.get_select_by_auto_role_id_string(id)) + for result in results: + auto_role_rules.append(AutoRoleRule( + result[1], + result[2], + result[3], + result[4], + result[5], + id=result[0] + )) - def find_auto_role_rules_by_auto_role_id(self, id: int) -> Optional[AutoRoleRule]: - self._logger.trace(__name__, f'Send SQL command: {AutoRoleRule.get_select_by_auto_role_id_string(id)}') - result = self._context.select(AutoRoleRule.get_select_by_auto_role_id_string(id)) - if result is None or len(result) == 0: - return None - - result = result[0] - - return AutoRoleRule( - result[1], - result[2], - result[3], - result[4], - result[5], - id=result[0] - ) + return auto_role_rules def add_auto_role_rule(self, auto_role_rule: AutoRoleRule): self._logger.trace(__name__, f'Send SQL command: {auto_role_rule.delete_string}') diff --git a/src/bot_data/service/migration_service.py b/src/bot_data/service/migration_service.py index fc2d5913..9df395bf 100644 --- a/src/bot_data/service/migration_service.py +++ b/src/bot_data/service/migration_service.py @@ -2,6 +2,7 @@ from typing import Type from cpl_core.database.context import DatabaseContextABC from cpl_core.dependency_injection import ServiceProviderABC +from cpl_query.extension import List from bot_core.logging.database_logger import DatabaseLogger from bot_data.abc.migration_abc import MigrationABC @@ -17,7 +18,7 @@ class MigrationService: self._db = db self._cursor = db.cursor - self._migrations: list[Type[MigrationABC]] = MigrationABC.__subclasses__() + self._migrations = List(type, MigrationABC.__subclasses__()).order_by(lambda x: x.name) def migrate(self): self._logger.info(__name__, f"Running Migrations") diff --git a/src/modules/autorole/auto_role_module.py b/src/modules/autorole/auto_role_module.py index 9afd273e..c8f83990 100644 --- a/src/modules/autorole/auto_role_module.py +++ b/src/modules/autorole/auto_role_module.py @@ -1,10 +1,13 @@ from cpl_core.configuration import ConfigurationABC from cpl_core.dependency_injection import ServiceCollectionABC from cpl_core.environment import ApplicationEnvironmentABC +from cpl_discord.discord_event_types_enum import DiscordEventTypesEnum from cpl_discord.service.discord_collection_abc import DiscordCollectionABC from bot_core.abc.module_abc import ModuleABC from bot_core.configuration.feature_flags_enum import FeatureFlagsEnum +from modules.autorole.events.auto_role_on_raw_reaction_add import AutoRoleOnRawReactionAddEvent +from modules.autorole.helper.reaction_handler import ReactionHandler class AutoRoleModule(ModuleABC): @@ -16,6 +19,7 @@ class AutoRoleModule(ModuleABC): pass def configure_services(self, services: ServiceCollectionABC, env: ApplicationEnvironmentABC): + services.add_transient(ReactionHandler) # commands # events - pass + self._dc.add_event(DiscordEventTypesEnum.on_raw_reaction_add.value, AutoRoleOnRawReactionAddEvent) diff --git a/src/modules/autorole/events/__init__.py b/src/modules/autorole/events/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modules/autorole/events/auto_role_on_raw_reaction_add.py b/src/modules/autorole/events/auto_role_on_raw_reaction_add.py new file mode 100644 index 00000000..586f87fd --- /dev/null +++ b/src/modules/autorole/events/auto_role_on_raw_reaction_add.py @@ -0,0 +1,34 @@ +from cpl_core.logging import LoggerABC +from cpl_discord.events.on_raw_reaction_add_abc import OnRawReactionAddABC +from cpl_discord.service import DiscordBotServiceABC +from discord import RawReactionActionEvent + +from bot_data.abc.auto_role_repository_abc import AutoRoleRepositoryABC +from bot_data.abc.server_repository_abc import ServerRepositoryABC +from modules.autorole.helper.reaction_handler import ReactionHandler + + +class AutoRoleOnRawReactionAddEvent(OnRawReactionAddABC): + + def __init__( + self, + logger: LoggerABC, + bot: DiscordBotServiceABC, + servers: ServerRepositoryABC, + auto_roles: AutoRoleRepositoryABC, + reaction_handler: ReactionHandler + ): + OnRawReactionAddABC.__init__(self) + + self._logger = logger + self._bot = bot + self._servers = servers + self._auto_roles = auto_roles + self._reaction_handler = reaction_handler + + async def on_raw_reaction_add(self, payload: RawReactionActionEvent): + self._logger.debug(__name__, f'Module {type(self)} started') + + await self._reaction_handler.handle(payload, 'add') + + self._logger.debug(__name__, f'Module {type(self)} stopped') diff --git a/src/modules/autorole/helper/__init__.py b/src/modules/autorole/helper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modules/autorole/helper/reaction_handler.py b/src/modules/autorole/helper/reaction_handler.py new file mode 100644 index 00000000..528a8968 --- /dev/null +++ b/src/modules/autorole/helper/reaction_handler.py @@ -0,0 +1,52 @@ +import discord +from cpl_core.logging import LoggerABC +from cpl_discord.service import DiscordBotServiceABC +from cpl_query.extension import List +from discord import RawReactionActionEvent + +from bot_data.abc.auto_role_repository_abc import AutoRoleRepositoryABC +from bot_data.abc.server_repository_abc import ServerRepositoryABC +from bot_data.model.auto_role import AutoRole +from bot_data.model.auto_role_rule import AutoRoleRule + + +class ReactionHandler: + + def __init__( + self, + logger: LoggerABC, + bot: DiscordBotServiceABC, + servers: ServerRepositoryABC, + auto_roles: AutoRoleRepositoryABC + ): + self._logger = logger + self._bot = bot + self._servers = servers + self._auto_roles = auto_roles + + self._message_ids = self._auto_roles.get_auto_roles().select(lambda x: x.discord_message_id) + self._roles = self._auto_roles.get_auto_roles() + + async def handle(self, payload: RawReactionActionEvent, r_type=None) -> None: + if payload.message_id not in self._message_ids: + return + + guild = self._bot.get_guild(payload.guild_id) + user = await guild.fetch_member(payload.user_id) + if user.bot: + return + + emoji = payload.emoji.name + auto_role: AutoRole = self._roles.where(lambda x: x.discord_message_id).first_or_default() + if auto_role is None: + return + + rules: List[AutoRoleRule] = self._auto_roles.get_auto_role_rules_by_auto_role_id(auto_role.auto_role_id) + + for rule in rules: + if emoji != rule.emoji_name: + continue + + role = guild.get_role(rule.role_id) + self._logger.debug(__name__, f'Assign role {role.name} to {user.name}') + await user.add_roles(role)