Improved cpl-discords service loading for commands and events

This commit is contained in:
Sven Heidemann 2023-10-12 20:41:34 +02:00
parent f357e97ce5
commit 1dd48899d7
6 changed files with 27 additions and 47 deletions

View File

@ -157,12 +157,12 @@ class ServiceProvider(ServiceProviderABC):
return implementation return implementation
def get_services(self, service_type: typing.Type[T], *args, **kwargs) -> list[Optional[T]]: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[T]]:
implementations = [] implementations = []
if typing.get_origin(service_type) != list: if typing.get_origin(service_type) == list:
raise Exception(f"Invalid type {service_type}! Expected list of type") raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
implementations.extend(self._get_services(typing.get_args(service_type)[0])) implementations.extend(self._get_services(service_type))
return implementations return implementations

View File

@ -76,7 +76,7 @@ class ServiceProviderABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type r"""Returns instance of given type
Parameter Parameter

View File

@ -3,8 +3,8 @@
"Name": "cpl-discord", "Name": "cpl-discord",
"Version": { "Version": {
"Major": "2023", "Major": "2023",
"Minor": "4", "Minor": "10",
"Micro": "0.post3" "Micro": "0"
}, },
"Author": "Sven Heidemann", "Author": "Sven Heidemann",
"AuthorEmail": "sven.heidemann@sh-edraft.de", "AuthorEmail": "sven.heidemann@sh-edraft.de",

View File

@ -1,11 +1,11 @@
from typing import Type, Optional from typing import Type
from cpl_core.console import Console, ForegroundColorEnum
from cpl_core.dependency_injection import ServiceCollectionABC from cpl_core.dependency_injection import ServiceCollectionABC
from cpl_discord.command.discord_command_abc import DiscordCommandABC from cpl_discord.command.discord_command_abc import DiscordCommandABC
from cpl_discord.discord_event_types_enum import DiscordEventTypesEnum from cpl_discord.discord_event_types_enum import DiscordEventTypesEnum
from cpl_discord.service.command_error_handler_service import CommandErrorHandlerService from cpl_discord.service.command_error_handler_service import CommandErrorHandlerService
from cpl_discord.service.discord_collection_abc import DiscordCollectionABC from cpl_discord.service.discord_collection_abc import DiscordCollectionABC
from cpl_query.extension.list import List
class DiscordCollection(DiscordCollectionABC): class DiscordCollection(DiscordCollectionABC):
@ -13,26 +13,21 @@ class DiscordCollection(DiscordCollectionABC):
DiscordCollectionABC.__init__(self) DiscordCollectionABC.__init__(self)
self._services = service_collection self._services = service_collection
self._events: dict[str, List] = {}
self._commands = List(type(DiscordCommandABC))
self.add_event(DiscordEventTypesEnum.on_command_error.value, CommandErrorHandlerService) self.add_event(DiscordEventTypesEnum.on_command_error.value, CommandErrorHandlerService)
def add_command(self, _t: Type[DiscordCommandABC]): def add_command(self, _t: Type[DiscordCommandABC]):
Console.set_foreground_color(ForegroundColorEnum.yellow)
Console.write_line(
f"{type(self).__name__}.add_command is deprecated. Instead, use ServiceCollection.add_transient directly!"
)
Console.color_reset()
self._services.add_transient(DiscordCommandABC, _t) self._services.add_transient(DiscordCommandABC, _t)
self._commands.append(_t)
def get_commands(self) -> List[DiscordCommandABC]:
return self._commands
def add_event(self, _t_event: Type, _t: Type): def add_event(self, _t_event: Type, _t: Type):
Console.set_foreground_color(ForegroundColorEnum.yellow)
Console.write_line(
f"{type(self).__name__}.add_event is deprecated. Instead, use ServiceCollection.add_transient directly!"
)
Console.color_reset()
self._services.add_transient(_t_event, _t) self._services.add_transient(_t_event, _t)
if _t_event not in self._events:
self._events[_t_event] = List(type(_t_event))
self._events[_t_event].append(_t)
def get_events_by_base(self, _t_event: Type) -> Optional[List]:
if _t_event not in self._events:
return None
return self._events[_t_event]

View File

@ -13,14 +13,6 @@ class DiscordCollectionABC(ABC):
def add_command(self, _t: Type[DiscordCommandABC]): def add_command(self, _t: Type[DiscordCommandABC]):
pass pass
@abstractmethod
def get_commands(self) -> List[DiscordCommandABC]:
pass
@abstractmethod @abstractmethod
def add_event(self, _t_event: Type, _t: Type): def add_event(self, _t_event: Type, _t: Type):
pass pass
@abstractmethod
def get_events_by_base(self, _t_event: Type):
pass

View File

@ -4,11 +4,12 @@ from typing import Optional, Sequence, Union, Type
import discord import discord
from discord import RawReactionActionEvent from discord import RawReactionActionEvent
from discord.ext import commands from discord.ext import commands
from discord.ext.commands import Context, CommandError, Cog, Command from discord.ext.commands import Context, CommandError, Cog
from cpl_core.dependency_injection import ServiceProviderABC from cpl_core.dependency_injection import ServiceProviderABC
from cpl_core.logging import LoggerABC from cpl_core.logging import LoggerABC
from cpl_core.utils import String from cpl_core.utils import String
from cpl_discord.command import DiscordCommandABC
from cpl_discord.command.discord_commands_meta import DiscordCogMeta from cpl_discord.command.discord_commands_meta import DiscordCogMeta
from cpl_discord.events.on_bulk_message_delete_abc import OnBulkMessageDeleteABC from cpl_discord.events.on_bulk_message_delete_abc import OnBulkMessageDeleteABC
from cpl_discord.events.on_command_abc import OnCommandABC from cpl_discord.events.on_command_abc import OnCommandABC
@ -66,25 +67,17 @@ from cpl_discord.events.on_typing_abc import OnTypingABC
from cpl_discord.events.on_user_update_abc import OnUserUpdateABC from cpl_discord.events.on_user_update_abc import OnUserUpdateABC
from cpl_discord.events.on_voice_state_update_abc import OnVoiceStateUpdateABC from cpl_discord.events.on_voice_state_update_abc import OnVoiceStateUpdateABC
from cpl_discord.events.on_webhooks_update_abc import OnWebhooksUpdateABC from cpl_discord.events.on_webhooks_update_abc import OnWebhooksUpdateABC
from cpl_discord.service.discord_collection_abc import DiscordCollectionABC
from cpl_discord.service.discord_service_abc import DiscordServiceABC from cpl_discord.service.discord_service_abc import DiscordServiceABC
class DiscordService(DiscordServiceABC, commands.Cog, metaclass=DiscordCogMeta): class DiscordService(DiscordServiceABC, commands.Cog, metaclass=DiscordCogMeta):
def __init__(self, logger: LoggerABC, dc_collection: DiscordCollectionABC, services: ServiceProviderABC): def __init__(self, logger: LoggerABC, services: ServiceProviderABC):
DiscordServiceABC.__init__(self) DiscordServiceABC.__init__(self)
self._logger = logger self._logger = logger
self._collection = dc_collection
self._services = services self._services = services
async def _handle_event(self, event: Type, *args, **kwargs): async def _handle_event(self, event: Type, *args, **kwargs):
event_collection = self._collection.get_events_by_base(event) for event_instance in self._services.get_services(event):
if event_collection is None:
return
for event_type in event_collection:
event_instance = self._services.get_service(event_type)
func_name = event.__name__ func_name = event.__name__
if func_name.endswith("ABC"): if func_name.endswith("ABC"):
func_name = func_name.replace("ABC", "") func_name = func_name.replace("ABC", "")
@ -104,11 +97,11 @@ class DiscordService(DiscordServiceABC, commands.Cog, metaclass=DiscordCogMeta):
self._logger.error(__name__, f"{type(self).__name__} initialization failed", e) self._logger.error(__name__, f"{type(self).__name__} initialization failed", e)
try: try:
for command_type in self._collection.get_commands(): for command in self._services.get_services(DiscordCommandABC):
self._logger.trace(__name__, f"Register command {command_type.__name__}") self._logger.trace(__name__, f"Register command {type(command).__name__}")
command: Cog = self._services.get_service(command_type) command: Cog = command
if command is None: if command is None:
self._logger.warn(__name__, f"Instance of {command_type.__name__} not found") self._logger.warn(__name__, f"Instance of {type(command).__name__} not found")
continue continue
await bot.add_cog(command) await bot.add_cog(command)
except Exception as e: except Exception as e: