diff --git a/bot/src/bot_core/abc/task_abc.py b/bot/src/bot_core/abc/task_abc.py index 5f578a99..9bbd3e54 100644 --- a/bot/src/bot_core/abc/task_abc.py +++ b/bot/src/bot_core/abc/task_abc.py @@ -21,7 +21,7 @@ class TaskABC(commands.Cog): @ServiceProviderABC.inject async def _wait_until_ready(self, config: ConfigurationABC, logger: TaskLogger, bot: DiscordBotServiceABC): - logger.debug(__name__, f"Waiting before {type(self).__name__}") + logger.debug(__name__, f"Waiting before ready {type(self).__name__}") await bot.wait_until_ready() async def wait(): diff --git a/bot/src/bot_data/model/scheduled_event.py b/bot/src/bot_data/model/scheduled_event.py index 041cba28..ca762eab 100644 --- a/bot/src/bot_data/model/scheduled_event.py +++ b/bot/src/bot_data/model/scheduled_event.py @@ -154,7 +154,7 @@ class ScheduledEvent(TableABC): {"NULL" if self._channel_id is None else f"'{self._channel_id}'"}, '{self._start_time}', {"NULL" if self._end_time is None else f"'{self._end_time}'"}, - '{self._entity_type}', + '{self._entity_type.value}', {"NULL" if self._location is None else f"'{self._location}'"}, {self._server.id} ); @@ -172,7 +172,7 @@ class ScheduledEvent(TableABC): `ChannelId` = {"NULL" if self._channel_id is None else f"'{self._channel_id}'"}, `StartTime` = '{self._start_time}', `EndTime` = {"NULL" if self._end_time is None else f"'{self._end_time}'"}, - `EntityType` = '{self._entity_type}', + `EntityType` = '{self._entity_type.value}', `Location` = {"NULL" if self._location is None else f"'{self._location}'"} WHERE `Id` = {self._id}; """ diff --git a/bot/src/bot_data/service/scheduled_event_repository_service.py b/bot/src/bot_data/service/scheduled_event_repository_service.py index 01d4fc6c..dc394fdf 100644 --- a/bot/src/bot_data/service/scheduled_event_repository_service.py +++ b/bot/src/bot_data/service/scheduled_event_repository_service.py @@ -2,6 +2,7 @@ from typing import Optional from cpl_core.database.context import DatabaseContextABC from cpl_query.extension import List +from discord import EntityType from bot_core.logging.database_logger import DatabaseLogger from bot_data.abc.server_repository_abc import ServerRepositoryABC @@ -39,7 +40,7 @@ class ScheduledEventRepositoryService(ScheduledEventRepositoryABC): int(self._get_value_from_result(sql_result[4])), # channel_id self._get_value_from_result(sql_result[5]), # start_time self._get_value_from_result(sql_result[6]), # end_time - self._get_value_from_result(sql_result[7]), # entity_type + EntityType(int(self._get_value_from_result(sql_result[7]))), # entity_type self._get_value_from_result(sql_result[8]), # location self._servers.get_server_by_id((sql_result[9])), # server self._get_value_from_result(sql_result[10]), # created_at diff --git a/bot/src/modules/base/base_module.py b/bot/src/modules/base/base_module.py index fc04e55e..df582b08 100644 --- a/bot/src/modules/base/base_module.py +++ b/bot/src/modules/base/base_module.py @@ -45,6 +45,7 @@ from modules.base.events.base_on_voice_state_update_event_scheduled_event_bonus from modules.base.forms.bug_report_form import BugReportForm from modules.base.forms.complaint_form import ComplaintForm from modules.base.helper.base_reaction_handler import BaseReactionHandler +from modules.base.scheduled_events_watcher import ScheduledEventsWatcher from modules.base.service.event_service import EventService from modules.base.service.user_warnings_service import UserWarningsService @@ -61,6 +62,7 @@ class BaseModule(ModuleABC): services.add_singleton(EventService) services.add_transient(UserWarningsService) services.add_singleton(TaskABC, BirthdayWatcher) + services.add_singleton(TaskABC, ScheduledEventsWatcher) # forms services.add_transient(BugReportForm) diff --git a/bot/src/modules/base/scheduled_events_watcher.py b/bot/src/modules/base/scheduled_events_watcher.py new file mode 100644 index 00000000..dc2bb849 --- /dev/null +++ b/bot/src/modules/base/scheduled_events_watcher.py @@ -0,0 +1,126 @@ +import calendar +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +from cpl_core.configuration import ConfigurationABC +from cpl_core.database.context import DatabaseContextABC +from cpl_discord.service import DiscordBotServiceABC +from cpl_query.extension import List +from cpl_translation import TranslatePipe +from discord import Guild, PrivacyLevel +from discord.ext import tasks +from discord.scheduled_event import ScheduledEvent as DiscordEvent + +from bot_core.abc.task_abc import TaskABC +from bot_core.logging.task_logger import TaskLogger +from bot_core.service.message_service import MessageService +from bot_data.abc.scheduled_event_repository_abc import ScheduledEventRepositoryABC +from bot_data.abc.server_repository_abc import ServerRepositoryABC +from bot_data.model.scheduled_event import ScheduledEvent +from bot_data.model.scheduled_event_interval_enum import ScheduledEventIntervalEnum + + +class ScheduledEventsWatcher(TaskABC): + def __init__( + self, + config: ConfigurationABC, + logger: TaskLogger, + bot: DiscordBotServiceABC, + db: DatabaseContextABC, + servers: ServerRepositoryABC, + events: ScheduledEventRepositoryABC, + message_service: MessageService, + t: TranslatePipe, + ): + TaskABC.__init__(self) + + self._config = config + self._logger = logger + self._bot = bot + self._db = db + self._servers = servers + self._events = events + self._message_service = message_service + self._t = t + + if not self._is_maintenance(): + self.watch.start() + + def _append_interval(self, interval: ScheduledEventIntervalEnum, ts: datetime) -> datetime: + now = datetime.now() + if ts >= now: + return ts + + if interval == ScheduledEventIntervalEnum.daily: + ts = ts + timedelta(days=1) + + elif interval == ScheduledEventIntervalEnum.weekly: + ts = ts + timedelta(weeks=1) + + elif interval == ScheduledEventIntervalEnum.monthly: + days_in_month = calendar.monthrange(ts.year, ts.month + 1)[1] + ts = ts + timedelta(days=days_in_month) + + elif interval == ScheduledEventIntervalEnum.yearly: + ts = ts + timedelta(days=365) + + return ts + + @tasks.loop(minutes=1) + async def watch(self): + self._logger.info(__name__, "Watching scheduled events") + try: + for guild in self._bot.guilds: + guild: Guild = guild + server = self._servers.get_server_by_discord_id(guild.id) + scheduled_events_from_guild = self._events.get_scheduled_events_by_server_id(server.id) + for scheduled_event in scheduled_events_from_guild: + scheduled_event: ScheduledEvent = scheduled_event + from_guild = List(DiscordEvent, guild.scheduled_events).where( + lambda x: x.name == scheduled_event.name + and x.description == scheduled_event.description + and x.entity_type == scheduled_event.entity_type + ) + if from_guild.count() != 0: + continue + + kwargs = {"name": scheduled_event.name, "description": scheduled_event.description} + + if scheduled_event.channel_id is not None: + kwargs["channel"] = guild.get_channel(scheduled_event.channel_id) + + if scheduled_event.start_time is not None: + scheduled_event.start_time = self._append_interval( + scheduled_event.interval, scheduled_event.start_time + ) + + start_time = scheduled_event.start_time.replace(tzinfo=ZoneInfo("Europe/Berlin")) + + kwargs["start_time"] = start_time + + if scheduled_event.end_time is not None: + scheduled_event.end_time = self._append_interval( + scheduled_event.interval, scheduled_event.end_time + ) + end_time = scheduled_event.end_time.replace(tzinfo=ZoneInfo("Europe/Berlin")) + kwargs["end_time"] = end_time + + kwargs["entity_type"] = scheduled_event.entity_type + if scheduled_event.location is not None: + kwargs["location"] = scheduled_event.location + + kwargs["privacy_level"] = PrivacyLevel.guild_only + + try: + self._logger.debug(__name__, f"Try to create scheduled event for guild {guild.name}") + await guild.create_scheduled_event(**kwargs) + self._events.update_scheduled_event(scheduled_event) + self._db.save_changes() + except Exception as e: + self._logger.error(__name__, f"Watching scheduled events failed", e) + except Exception as e: + self._logger.error(__name__, f"Watching scheduled events failed", e) + + @watch.before_loop + async def wait(self): + await self._wait_until_ready()