diff --git a/src/cpl_discord/command/discord_command_abc.py b/src/cpl_discord/command/discord_command_abc.py index 19e3290c..fe3072be 100644 --- a/src/cpl_discord/command/discord_command_abc.py +++ b/src/cpl_discord/command/discord_command_abc.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from discord.ext import commands -from discord_commands_meta import DiscordCogMeta +from cpl_discord.command.discord_commands_meta import DiscordCogMeta class DiscordCommandABC(ABC, commands.Cog, metaclass=DiscordCogMeta): diff --git a/src/cpl_discord/event_types_enum.py b/src/cpl_discord/discord_event_types_enum.py similarity index 58% rename from src/cpl_discord/event_types_enum.py rename to src/cpl_discord/discord_event_types_enum.py index a028ac52..d963435d 100644 --- a/src/cpl_discord/event_types_enum.py +++ b/src/cpl_discord/discord_event_types_enum.py @@ -1,6 +1,9 @@ from enum import Enum from cpl_discord.events.on_bulk_message_delete_abc import OnBulkMessageDeleteABC +from cpl_discord.events.on_command_abc import OnCommandABC +from cpl_discord.events.on_command_completion_abc import OnCommandCompletionABC +from cpl_discord.events.on_command_error_abc import OnCommandErrorABC from cpl_discord.events.on_connect_abc import OnConnectABC from cpl_discord.events.on_disconnect_abc import OnDisconnectABC from cpl_discord.events.on_group_join_abc import OnGroupJoinABC @@ -48,51 +51,53 @@ from cpl_discord.events.on_voice_state_update_abc import OnVoiceStateUpdateABC from cpl_discord.events.on_webhooks_update_abc import OnWebhooksUpdateABC -class EventTypesEnum(Enum): +class DiscordEventTypesEnum(Enum): + on_bulk_message_delete = OnBulkMessageDeleteABC + on_command = OnCommandABC + on_command_error = OnCommandErrorABC + on_command_completion = OnCommandCompletionABC + on_connect = OnConnectABC + on_disconnect = OnDisconnectABC + on_group_join = OnGroupJoinABC + on_group_remove = OnGroupRemoveABC + on_guild_available = OnGuildAvailableABC + on_guild_channel_create = OnGuildChannelCreateABC + on_guild_channel_delete = OnGuildChannelDeleteABC + on_guild_channel_pins_update = OnGuildChannelPinsUpdateABC + on_guild_channel_update = OnGuildChannelUpdateABC + on_guild_emojis_update = OnGuildEmojisUpdateABC + on_guild_integrations_update = OnGuildIntegrationsUpdateABC + on_guild_join = OnGuildJoinABC + on_guild_remove = OnGuildRemoveABC + on_guild_role_create = OnGuildRoleCreateABC + on_guild_role_delete = OnGuildRoleDeleteABC + on_guild_role_update = OnGuildRoleUpdateABC + on_guild_unavailable = OnGuildUnavailableABC + on_guild_update = OnGuildUpdateABC + on_invite_create = OnInviteCreateABC + on_invite_delete = OnInviteDeleteABC + on_member_ban = OnMemberBanABC + on_member_join = OnMemberJoinABC + on_member_remove = OnMemberRemoveABC + on_member_unban = OnMemberUnbanABC + on_member_update = OnMemberUpdateABC + on_message = OnMessageABC + on_message_delete = OnMessageDeleteABC + on_message_edit = OnMessageEditABC + on_private_channel_create = OnPrivateChannelCreateABC + on_private_channel_delete = OnPrivateChannelDeleteABC + on_private_channel_pins_update = OnPrivateChannelPinsUpdateABC + on_private_channel_update = OnPrivateChannelUpdateABC + on_reaction_add = OnReactionAddABC + on_reaction_clear = OnReactionClearABC + on_reaction_clear_emoji = OnReactionClearEmojiABC + on_reaction_remove = OnReactionRemoveABC on_ready = OnReadyABC - on_bulk_message_delete_abc = OnBulkMessageDeleteABC - on_connect_abc = OnConnectABC - on_disconnect_abc = OnDisconnectABC - on_group_join_abc = OnGroupJoinABC - on_group_remove_abc = OnGroupRemoveABC - on_guild_available_abc = OnGuildAvailableABC - on_guild_channel_create_abc = OnGuildChannelCreateABC - on_guild_channel_delete_abc = OnGuildChannelDeleteABC - on_guild_channel_pins_update_abc = OnGuildChannelPinsUpdateABC - on_guild_channel_update_abc = OnGuildChannelUpdateABC - on_guild_emojis_update_abc = OnGuildEmojisUpdateABC - on_guild_integrations_update_abc = OnGuildIntegrationsUpdateABC - on_guild_join_abc = OnGuildJoinABC - on_guild_remove_abc = OnGuildRemoveABC - on_guild_role_create_abc = OnGuildRoleCreateABC - on_guild_role_delete_abc = OnGuildRoleDeleteABC - on_guild_role_update_abc = OnGuildRoleUpdateABC - on_guild_unavailable_abc = OnGuildUnavailableABC - on_guild_update_abc = OnGuildUpdateABC - on_invite_create_abc = OnInviteCreateABC - on_invite_delete_abc = OnInviteDeleteABC - on_member_ban_abc = OnMemberBanABC - on_member_join_abc = OnMemberJoinABC - on_member_remove_abc = OnMemberRemoveABC - on_member_unban_abc = OnMemberUnbanABC - on_member_update_abc = OnMemberUpdateABC - on_message_abc = OnMessageABC - on_message_delete_abc = OnMessageDeleteABC - on_message_edit_abc = OnMessageEditABC - on_private_channel_create_abc = OnPrivateChannelCreateABC - on_private_channel_delete_abc = OnPrivateChannelDeleteABC - on_private_channel_pins_update_abc = OnPrivateChannelPinsUpdateABC - on_private_channel_update_abc = OnPrivateChannelUpdateABC - on_reaction_add_abc = OnReactionAddABC - on_reaction_clear_abc = OnReactionClearABC - on_reaction_clear_emoji_abc = OnReactionClearEmojiABC - on_reaction_remove_abc = OnReactionRemoveABC - on_ready_abc = OnReadyABC - on_relationship_add_abc = OnRelationshipAddABC - on_relationship_remove_abc = OnRelationshipRemoveABC - on_relationship_update_abc = OnRelationshipUpdateABC - on_resume_abc = OnResumeABC - on_typing_abc = OnTypingABC - on_user_update_abc = OnUserUpdateABC - on_voice_state_update_abc = OnVoiceStateUpdateABC - on_webhooks_update_abc = OnWebhooksUpdateABC + on_relationship_add = OnRelationshipAddABC + on_relationship_remove = OnRelationshipRemoveABC + on_relationship_update = OnRelationshipUpdateABC + on_resume = OnResumeABC + on_typing = OnTypingABC + on_user_update = OnUserUpdateABC + on_voice_state_update = OnVoiceStateUpdateABC + on_webhooks_update = OnWebhooksUpdateABC diff --git a/src/cpl_discord/events/on_command_abc.py b/src/cpl_discord/events/on_command_abc.py new file mode 100644 index 00000000..addda92b --- /dev/null +++ b/src/cpl_discord/events/on_command_abc.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + +from discord.ext.commands import Context + + +class OnCommandABC(ABC): + + @abstractmethod + def __init__(self): pass + + @abstractmethod + async def on_command(self, ctx: Context): pass diff --git a/src/cpl_discord/events/on_command_completion_abc.py b/src/cpl_discord/events/on_command_completion_abc.py new file mode 100644 index 00000000..d87c6077 --- /dev/null +++ b/src/cpl_discord/events/on_command_completion_abc.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + +from discord.ext.commands import Context, CommandError + + +class OnCommandCompletionABC(ABC): + + @abstractmethod + def __init__(self): pass + + @abstractmethod + async def on_command_completion(self, ctx: Context): pass diff --git a/src/cpl_discord/events/on_command_error_abc.py b/src/cpl_discord/events/on_command_error_abc.py new file mode 100644 index 00000000..e40989a4 --- /dev/null +++ b/src/cpl_discord/events/on_command_error_abc.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + +from discord.ext.commands import Context, CommandError + + +class OnCommandErrorABC(ABC): + + @abstractmethod + def __init__(self): pass + + @abstractmethod + async def on_command_error(self, ctx: Context, error: CommandError): pass diff --git a/src/cpl_discord/service/command_error_handler_service.py b/src/cpl_discord/service/command_error_handler_service.py new file mode 100644 index 00000000..0893d866 --- /dev/null +++ b/src/cpl_discord/service/command_error_handler_service.py @@ -0,0 +1,14 @@ +from discord.ext.commands import Context, CommandError + +from cpl_core.logging import LoggerABC +from cpl_discord.events.on_command_error_abc import OnCommandErrorABC + + +class CommandErrorHandlerService(OnCommandErrorABC): + + def __init__(self, logger: LoggerABC): + OnCommandErrorABC.__init__(self) + self._logger = logger + + async def on_command_error(self, ctx: Context, error: CommandError): + self._logger.error(__name__, f'Error in command: {ctx.command}', error) diff --git a/src/cpl_discord/service/discord_bot_service.py b/src/cpl_discord/service/discord_bot_service.py index 00d2c767..8d4b1afa 100644 --- a/src/cpl_discord/service/discord_bot_service.py +++ b/src/cpl_discord/service/discord_bot_service.py @@ -1,7 +1,4 @@ -import sys - import discord -from discord.ext import commands from cpl_core.configuration import ConfigurationABC from cpl_core.console import Console @@ -66,6 +63,6 @@ class DiscordBotService(DiscordBotServiceABC): if self._logging_st.console.value >= LoggingLevelEnum.INFO.value: Console.banner(self._env.application_name if self._env.application_name != '' else 'A bot') - self.add_cog(self._discord_service) + self._discord_service.init(self) await self._discord_service.on_ready() diff --git a/src/cpl_discord/service/discord_collection.py b/src/cpl_discord/service/discord_collection.py index a30773b9..49e9d092 100644 --- a/src/cpl_discord/service/discord_collection.py +++ b/src/cpl_discord/service/discord_collection.py @@ -2,6 +2,8 @@ from typing import Type, Optional from cpl_core.console import Console from cpl_core.dependency_injection import ServiceCollectionABC +from cpl_discord.discord_event_types_enum import DiscordEventTypesEnum +from cpl_discord.service.command_error_handler_service import CommandErrorHandlerService from cpl_discord.service.discord_collection_abc import DiscordCollectionABC from cpl_query.extension import List @@ -15,9 +17,16 @@ class DiscordCollection(DiscordCollectionABC): self._services = service_collection self._events: dict[str, List] = {} + self._commands = List(type(CommandABC)) + + self.add_event(DiscordEventTypesEnum.on_command_error.value, CommandErrorHandlerService) def add_command(self, _t: Type[CommandABC]): self._services.add_transient(CommandABC, _t) + self._commands.append(_t) + + def get_commands(self) -> List[CommandABC]: + return self._commands def add_event(self, _t_event: Type, _t: Type): self._services.add_transient(_t_event, _t) diff --git a/src/cpl_discord/service/discord_collection_abc.py b/src/cpl_discord/service/discord_collection_abc.py index 5262f1a8..122eb100 100644 --- a/src/cpl_discord/service/discord_collection_abc.py +++ b/src/cpl_discord/service/discord_collection_abc.py @@ -14,6 +14,9 @@ class DiscordCollectionABC(ABC): @abstractmethod def add_command(self, _t: Type[CommandABC]): pass + @abstractmethod + def get_commands(self) -> List[CommandABC]: pass + @abstractmethod def add_event(self, _t_event: Type, _t: Type): pass diff --git a/src/cpl_discord/service/discord_service.py b/src/cpl_discord/service/discord_service.py index 247ea1c2..2e2a01e1 100644 --- a/src/cpl_discord/service/discord_service.py +++ b/src/cpl_discord/service/discord_service.py @@ -3,6 +3,7 @@ from typing import Optional, Sequence, Union, Type import discord from discord.ext import commands +from discord.ext.commands import Context, CommandError from cpl_core.console import Console from cpl_core.dependency_injection import ServiceProviderABC @@ -10,6 +11,9 @@ from cpl_core.logging import LoggerABC from cpl_core.utils import String from cpl_discord.command.discord_commands_meta import DiscordCogMeta from cpl_discord.events.on_bulk_message_delete_abc import OnBulkMessageDeleteABC +from cpl_discord.events.on_command_abc import OnCommandABC +from cpl_discord.events.on_command_completion_abc import OnCommandCompletionABC +from cpl_discord.events.on_command_error_abc import OnCommandErrorABC from cpl_discord.events.on_connect_abc import OnConnectABC from cpl_discord.events.on_disconnect_abc import OnDisconnectABC from cpl_discord.events.on_group_join_abc import OnGroupJoinABC @@ -90,13 +94,45 @@ class DiscordService(DiscordServiceABC, commands.Cog, metaclass=DiscordCogMeta): func = getattr(event_instance, func_name) await func(*args) except Exception as e: - self._logger.error(__name__, f'Cannot execute {func_name} of {event_instance.__name__}', e) + self._logger.error(__name__, f'Cannot execute {func_name} of {type(event_instance).__name__}', e) + + def init(self, bot: commands.Bot): + try: + bot.add_cog(self) + except Exception as e: + self._logger.error(__name__, f'{type(self).__name__} initialization failed', e) + + try: + for command_type in self._collection.get_commands(): + self._logger.trace(__name__, f'Register command {command_type.__name__}') + command = self._services.get_service(command_type) + if command is None: + self._logger.warn(__name__, f'Instance of {command_type.__name__} not found') + continue + bot.add_cog(command) + except Exception as e: + self._logger.error(__name__, f'Registration of commands failed', e) @commands.Cog.listener() async def on_connect(self): self._logger.trace(__name__, f'Received on_connect') await self._handle_event(OnConnectABC) + @commands.Cog.listener() + async def on_command(self, ctx: Context): + self._logger.trace(__name__, f'Received on_command') + await self._handle_event(OnCommandABC, ctx) + + @commands.Cog.listener() + async def on_command_error(self, ctx: Context, error: CommandError): + self._logger.trace(__name__, f'Received on_command_error') + await self._handle_event(OnCommandErrorABC, ctx, error) + + @commands.Cog.listener() + async def on_command_completion(self, ctx: Context): + self._logger.trace(__name__, f'Received on_command_completion') + await self._handle_event(OnCommandCompletionABC, ctx) + @commands.Cog.listener() async def on_disconnect(self): self._logger.trace(__name__, f'Received on_disconnect') diff --git a/src/cpl_discord/service/discord_service_abc.py b/src/cpl_discord/service/discord_service_abc.py index d6cc8379..6d7c48a3 100644 --- a/src/cpl_discord/service/discord_service_abc.py +++ b/src/cpl_discord/service/discord_service_abc.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional, Sequence, Union import discord +from discord.ext import commands class DiscordServiceABC(ABC): @@ -10,9 +11,21 @@ class DiscordServiceABC(ABC): def __init__(self): ABC.__init__(self) + @abstractmethod + def init(self, bot: commands.Bot): pass + @abstractmethod async def on_connect(self): pass + @abstractmethod + async def on_command(self): pass + + @abstractmethod + async def on_command_error(self): pass + + @abstractmethod + async def on_command_completion(self): pass + @abstractmethod async def on_disconnect(self): pass diff --git a/src/tests/custom/discord/src/discord_bot/main.py b/src/tests/custom/discord/src/discord_bot/main.py index 909287ae..f6c79e69 100644 --- a/src/tests/custom/discord/src/discord_bot/main.py +++ b/src/tests/custom/discord/src/discord_bot/main.py @@ -1,4 +1,5 @@ import asyncio +from typing import Optional from cpl_core.application import ApplicationBuilder @@ -6,12 +7,24 @@ from discord_bot.application import Application from discord_bot.startup import Startup -async def main(): - app_builder = ApplicationBuilder(Application) - app_builder.use_startup(Startup) - app: Application = await app_builder.build_async() - await app.run_async() +class Main: + + def __init__(self): + self._app: Optional[Application] = None + + async def main(self): + app_builder = ApplicationBuilder(Application) + app_builder.use_startup(Startup) + self._app: Application = await app_builder.build_async() + await self._app.run_async() + + async def stop(self): + await self._app.stop_async() if __name__ == '__main__': - asyncio.run(main()) + main = Main() + try: + asyncio.run(main.main()) + except KeyboardInterrupt: + asyncio.run(main.stop()) diff --git a/src/tests/custom/discord/src/discord_bot/startup.py b/src/tests/custom/discord/src/discord_bot/startup.py index 8f03ab7d..e138c5d1 100644 --- a/src/tests/custom/discord/src/discord_bot/startup.py +++ b/src/tests/custom/discord/src/discord_bot/startup.py @@ -3,9 +3,10 @@ from cpl_core.configuration import ConfigurationABC from cpl_core.dependency_injection import ServiceProviderABC, ServiceCollectionABC from cpl_core.environment import ApplicationEnvironment from cpl_discord import get_discord_collection -from cpl_discord.event_types_enum import EventTypesEnum +from cpl_discord.discord_event_types_enum import DiscordEventTypesEnum from modules.hello_world.on_ready_event import OnReadyEvent from modules.hello_world.on_ready_test_event import OnReadyTestEvent +from modules.hello_world.ping_command import PingCommand class Startup(StartupABC): @@ -24,7 +25,8 @@ class Startup(StartupABC): services.add_logging() services.add_discord() dc_collection = get_discord_collection(services) - dc_collection.add_event(EventTypesEnum.on_ready.value, OnReadyEvent) - dc_collection.add_event(EventTypesEnum.on_ready.value, OnReadyTestEvent) + dc_collection.add_event(DiscordEventTypesEnum.on_ready.value, OnReadyEvent) + dc_collection.add_event(DiscordEventTypesEnum.on_ready.value, OnReadyTestEvent) + dc_collection.add_command(PingCommand) return services.build_service_provider()