Added *args and **kwargs support to discord bot and DI

This commit is contained in:
Sven Heidemann 2023-01-12 08:57:01 +01:00
parent e0ca7c2ae6
commit 7f46fbe87a
6 changed files with 27 additions and 18 deletions

View File

@ -4,7 +4,7 @@
"Version": { "Version": {
"Major": "2022", "Major": "2022",
"Minor": "12", "Minor": "12",
"Micro": "1.post1" "Micro": "1.post2"
}, },
"Author": "Sven Heidemann", "Author": "Sven Heidemann",
"AuthorEmail": "sven.heidemann@sh-edraft.de", "AuthorEmail": "sven.heidemann@sh-edraft.de",

View File

@ -57,7 +57,7 @@ class ServiceProvider(ServiceProviderABC):
# raise Exception(f'Service {parameter.annotation} not found') # raise Exception(f'Service {parameter.annotation} not found')
def _get_services(self, t: type) -> list[Optional[object]]: def _get_services(self, t: type, *args, **kwargs) -> list[Optional[object]]:
implementations = [] implementations = []
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if descriptor.service_type == t or issubclass(descriptor.service_type, t): if descriptor.service_type == t or issubclass(descriptor.service_type, t):
@ -65,7 +65,7 @@ class ServiceProvider(ServiceProviderABC):
implementations.append(descriptor.implementation) implementations.append(descriptor.implementation)
continue continue
implementation = self.build_service(descriptor.service_type) implementation = self.build_service(descriptor.service_type, *args, **kwargs)
if descriptor.lifetime == ServiceLifetimeEnum.singleton: if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptor.implementation = implementation descriptor.implementation = implementation
@ -102,7 +102,7 @@ class ServiceProvider(ServiceProviderABC):
return params return params
def build_service(self, service_type: type) -> object: def build_service(self, service_type: type, *args, **kwargs) -> object:
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type): if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
if descriptor.implementation is not None: if descriptor.implementation is not None:
@ -115,7 +115,7 @@ class ServiceProvider(ServiceProviderABC):
sig = signature(service_type.__init__) sig = signature(service_type.__init__)
params = self.build_by_signature(sig) params = self.build_by_signature(sig)
return service_type(*params) return service_type(*params, *args, **kwargs)
def set_scope(self, scope: ScopeABC): def set_scope(self, scope: ScopeABC):
self._scope = scope self._scope = scope
@ -124,7 +124,7 @@ class ServiceProvider(ServiceProviderABC):
sb = ScopeBuilder(ServiceProvider(copy.deepcopy(self._service_descriptors), self._configuration, self._database_context)) sb = ScopeBuilder(ServiceProvider(copy.deepcopy(self._service_descriptors), self._configuration, self._database_context))
return sb.build() return sb.build()
def get_service(self, service_type: T) -> Optional[T]: def get_service(self, service_type: T, *args, **kwargs) -> Optional[T]:
result = self._find_service(service_type) result = self._find_service(service_type)
if result is None: if result is None:
@ -133,13 +133,13 @@ class ServiceProvider(ServiceProviderABC):
if result.implementation is not None: if result.implementation is not None:
return result.implementation return result.implementation
implementation = self.build_service(service_type) implementation = self.build_service(service_type, *args, **kwargs)
if result.lifetime == ServiceLifetimeEnum.singleton or result.lifetime == ServiceLifetimeEnum.scoped and self._scope is not None: if result.lifetime == ServiceLifetimeEnum.singleton or result.lifetime == ServiceLifetimeEnum.scoped and self._scope is not None:
result.implementation = implementation result.implementation = implementation
return implementation return implementation
def get_services(self, service_type: T) -> list[Optional[T]]: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[T]]:
implementations = [] implementations = []
if typing.get_origin(service_type) != list: if typing.get_origin(service_type) != list:

View File

@ -25,7 +25,7 @@ class ServiceProviderABC(ABC):
pass pass
@abstractmethod @abstractmethod
def build_service(self, service_type: type) -> object: def build_service(self, service_type: type, *args, **kwargs) -> object:
r"""Creates instance of given type r"""Creates instance of given type
Parameter Parameter
@ -61,7 +61,7 @@ class ServiceProviderABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_service(self, instance_type: T) -> Optional[T]: def get_service(self, instance_type: T, *args, **kwargs) -> Optional[T]:
r"""Returns instance of given type r"""Returns instance of given type
Parameter Parameter
@ -76,7 +76,7 @@ class ServiceProviderABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_services(self, service_type: T) -> list[Optional[T]]: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type r"""Returns instance of given type
Parameter Parameter

View File

@ -4,7 +4,7 @@
"Version": { "Version": {
"Major": "2022", "Major": "2022",
"Minor": "12", "Minor": "12",
"Micro": "1" "Micro": "1.post1"
}, },
"Author": "Sven Heidemann", "Author": "Sven Heidemann",
"AuthorEmail": "sven.heidemann@sh-edraft.de", "AuthorEmail": "sven.heidemann@sh-edraft.de",
@ -16,7 +16,7 @@
"LicenseName": "MIT", "LicenseName": "MIT",
"LicenseDescription": "MIT, see LICENSE for more details.", "LicenseDescription": "MIT, see LICENSE for more details.",
"Dependencies": [ "Dependencies": [
"cpl-core>=2022.12.1", "cpl-core>=2022.12.1.post2",
"discord.py==2.1.0", "discord.py==2.1.0",
"cpl-query>=2022.12.2.post1" "cpl-query>=2022.12.2.post1"
], ],

View File

@ -21,7 +21,9 @@ class DiscordBotService(DiscordBotServiceABC):
discord_bot_settings: DiscordBotSettings, discord_bot_settings: DiscordBotSettings,
env: ApplicationEnvironmentABC, env: ApplicationEnvironmentABC,
logging_st: LoggingSettings, logging_st: LoggingSettings,
discord_service: DiscordServiceABC discord_service: DiscordServiceABC,
*args,
**kwargs
): ):
# services # services
self._config = config self._config = config
@ -34,7 +36,12 @@ class DiscordBotService(DiscordBotServiceABC):
self._discord_settings = self._get_settings(discord_bot_settings) self._discord_settings = self._get_settings(discord_bot_settings)
# setup super # setup super
DiscordBotServiceABC.__init__(self, command_prefix=self._discord_settings.prefix, help_command=None, intents=discord.Intents().all()) DiscordBotServiceABC.__init__(
self,
*args,
command_prefix=self._discord_settings.prefix, help_command=None, intents=discord.Intents().all(),
**kwargs
)
self._base = super(DiscordBotServiceABC, self) self._base = super(DiscordBotServiceABC, self)
@staticmethod @staticmethod
@ -50,7 +57,9 @@ class DiscordBotService(DiscordBotServiceABC):
new_settings.from_dict({ new_settings.from_dict({
'Token': env_token if token is None or token == '' else token, 'Token': env_token if token is None or token == '' else token,
'Prefix': ('! ' if self._is_string_invalid(env_prefix) else env_prefix) if self._is_string_invalid(prefix) else prefix 'Prefix':
('! ' if self._is_string_invalid(env_prefix) else env_prefix)
if self._is_string_invalid(prefix) else prefix
}) })
if new_settings.token is None or new_settings.token == '': if new_settings.token is None or new_settings.token == '':
raise Exception('You have to configure discord token by appsettings or environment variables') raise Exception('You have to configure discord token by appsettings or environment variables')

View File

@ -8,8 +8,8 @@ from cpl_query.extension.list import List
class DiscordBotServiceABC(commands.Bot): class DiscordBotServiceABC(commands.Bot):
def __init__(self, **kwargs): def __init__(self, *args, **kwargs):
commands.Bot.__init__(self, **kwargs) commands.Bot.__init__(self, *args, **kwargs)
@abstractmethod @abstractmethod
async def start_async(self): pass async def start_async(self): pass