From 93269908bc11a0d1c14f95f1df5c5052fa0f65d0 Mon Sep 17 00:00:00 2001 From: Sven Heidemann Date: Tue, 30 Nov 2021 17:59:44 +0100 Subject: [PATCH] Improved database connection --- cpl-workspace.json | 3 +- src/gismo/main.py | 5 +- src/gismo/startup.py | 2 + src/gismo_data/abc/server_repository_abc.py | 2 +- src/gismo_data/abc/user_repository_abc.py | 2 +- src/gismo_data/db_context.py | 4 +- src/gismo_data/model/server.py | 14 +- src/gismo_data/model/user.py | 10 +- .../service/server_repository_service.py | 9 +- .../service/user_repository_service.py | 12 +- src/modules/database/__init__.py | 1 + src/modules/database/database.json | 43 +++++++ src/modules/database/database.py | 120 ++++++++++++++++++ src/modules/database/database_extension.py | 17 +++ 14 files changed, 225 insertions(+), 19 deletions(-) create mode 100644 src/modules/database/__init__.py create mode 100644 src/modules/database/database.json create mode 100644 src/modules/database/database.py create mode 100644 src/modules/database/database_extension.py diff --git a/cpl-workspace.json b/cpl-workspace.json index 72237cd..b0db2c0 100644 --- a/cpl-workspace.json +++ b/cpl-workspace.json @@ -9,7 +9,8 @@ "modules-core": "src/modules_core/modules-core.json", "boot-log": "src/modules/boot_log/boot-log.json", "level-generator": "tools/level_generator/level-generator.json", - "ontime-calculator": "tools/ontime_calculator/ontime-calculator.json" + "ontime-calculator": "tools/ontime_calculator/ontime-calculator.json", + "database": "src/modules/database/database.json" }, "Scripts": { "build-start": "cd src/gismo_cli; echo 'gismo-cli:'; cpl build; cd ../gismo; echo 'gismo:'; cpl build; cd ../../dist/gismo/build/gismo; bash gismo", diff --git a/src/gismo/main.py b/src/gismo/main.py index 4717405..87df32e 100644 --- a/src/gismo/main.py +++ b/src/gismo/main.py @@ -6,6 +6,7 @@ from cpl_core.application import ApplicationBuilder from gismo.application import Gismo from gismo.startup import Startup from modules.boot_log.boot_log_extension import BootLogExtension +from modules.database.database_extension import DatabaseExtension class Main: @@ -15,6 +16,7 @@ class Main: async def main(self): app_builder = ApplicationBuilder(Gismo) + app_builder.use_extension(DatabaseExtension) app_builder.use_extension(BootLogExtension) app_builder.use_startup(Startup) self._gismo: Gismo = await app_builder.build_async() @@ -31,7 +33,6 @@ if __name__ == '__main__': ml.run_until_complete(main.main()) except KeyboardInterrupt: ml.run_until_complete(main.stop()) - # (( # ( `) @@ -41,4 +42,4 @@ if __name__ == '__main__': # / ~/ # / ) ) ~ edraft # ___// | / -# `--' \_~-, \ No newline at end of file +# `--' \_~-, diff --git a/src/gismo/startup.py b/src/gismo/startup.py index a09983e..f359a18 100644 --- a/src/gismo/startup.py +++ b/src/gismo/startup.py @@ -19,6 +19,7 @@ from gismo_data.db_context import DBContext from gismo_data.service.server_repository_service import ServerRepositoryService from gismo_data.service.user_repository_service import UserRepositoryService from modules.boot_log.boot_log import BootLog +from modules.database.database import Database from modules_core.abc.module_abc import ModuleABC from modules_core.abc.module_service_abc import ModuleServiceABC from modules_core.service.module_service import ModuleService @@ -57,6 +58,7 @@ class Startup(StartupABC): services.add_transient(ServerRepositoryABC, ServerRepositoryService) services.add_transient(UserRepositoryABC, UserRepositoryService) + services.add_transient(ModuleABC, Database) services.add_transient(ModuleABC, BootLog) provider: ServiceProviderABC = services.build_service_provider() diff --git a/src/gismo_data/abc/server_repository_abc.py b/src/gismo_data/abc/server_repository_abc.py index 2c082f1..e8bfd38 100644 --- a/src/gismo_data/abc/server_repository_abc.py +++ b/src/gismo_data/abc/server_repository_abc.py @@ -24,7 +24,7 @@ class ServerRepositoryABC(ABC): def find_server_by_discord_id(self, discord_id: int) -> Optional[Server]: pass @abstractmethod - def add_server(self, server: Server) -> int: pass + def add_server(self, server: Server): pass @abstractmethod def update_server(self, server: Server): pass diff --git a/src/gismo_data/abc/user_repository_abc.py b/src/gismo_data/abc/user_repository_abc.py index aa40998..c198753 100644 --- a/src/gismo_data/abc/user_repository_abc.py +++ b/src/gismo_data/abc/user_repository_abc.py @@ -24,7 +24,7 @@ class UserRepositoryABC(ABC): def find_user_by_discord_id(self, discord_id: int) -> Optional[User]: pass @abstractmethod - def add_user(self, user: User) -> int: pass + def add_user(self, user: User): pass @abstractmethod def update_user(self, user: User): pass diff --git a/src/gismo_data/db_context.py b/src/gismo_data/db_context.py index 77ea0b1..aae38ca 100644 --- a/src/gismo_data/db_context.py +++ b/src/gismo_data/db_context.py @@ -29,10 +29,10 @@ class DBContext(DatabaseContext): def save_changes(self): try: self._logger.trace(__name__, "Save changes") - super(DatabaseContext, self).save_changes + super(DBContext, self).save_changes() self._logger.debug(__name__, "Saved changes") except Exception as e: self._logger.error(__name__, "Saving changes failed", e) def select(self, statement: str) -> list[tuple]: - return super(DatabaseContext, self).select(statement) \ No newline at end of file + return super(DBContext, self).select(statement) \ No newline at end of file diff --git a/src/gismo_data/model/server.py b/src/gismo_data/model/server.py index ea6626f..207bc35 100644 --- a/src/gismo_data/model/server.py +++ b/src/gismo_data/model/server.py @@ -1,13 +1,19 @@ +from datetime import datetime from typing import Optional + from cpl_core.database import TableABC class Server(TableABC): - def __init__(self, dc_id: int, id=0): + def __init__(self, dc_id: int, created_at: datetime=None, modified_at: datetime=None, id=0): self._server_id = id self._discord_server_id = dc_id + TableABC.__init__(self) + self._created_at = created_at if created_at is not None else self._created_at + self._modified_at = modified_at if modified_at is not None else self._modified_at + @property def server_id(self) -> int: return self._server_id @@ -20,8 +26,10 @@ class Server(TableABC): def get_create_string() -> str: return str(f""" CREATE TABLE IF NOT EXISTS `Servers` ( - `ServerId` INT(30) NOT NULL AUTO_INCREMENT, - `DiscordServerId` INT(30) NOT NULL, + `ServerId` BIGINT NOT NULL AUTO_INCREMENT, + `DiscordServerId` BIGINT NOT NULL, + `CreatedAt` DATETIME(6), + `LastModifiedAt` DATETIME(6), PRIMARY KEY(`ServerId`) ); """) diff --git a/src/gismo_data/model/user.py b/src/gismo_data/model/user.py index da75bae..3bbc27a 100644 --- a/src/gismo_data/model/user.py +++ b/src/gismo_data/model/user.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional from cpl_core.database import TableABC @@ -6,11 +7,15 @@ from gismo_data.model.server import Server class User(TableABC): - def __init__(self, dc_id: int, xp: int, server: Optional[Server], id=0): + def __init__(self, dc_id: int, xp: int, server: Optional[Server], created_at: datetime = None, modified_at: datetime = None, id=0): self._user_id = id self._discord_id = dc_id self._xp = xp self._server = server + + TableABC.__init__(self) + self._created_at = created_at if created_at is not None else self._created_at + self._modified_at = modified_at if modified_at is not None else self._modified_at @property def user_id(self) -> int: @@ -26,6 +31,7 @@ class User(TableABC): @xp.setter def xp(self, value: int): + self._modified_at = datetime.now().isoformat() self._xp = value @property @@ -42,7 +48,7 @@ class User(TableABC): `ServerId` BIGINT, `CreatedAt` DATETIME(6), `LastModifiedAt` DATETIME(6), - FOREIGN KEY (`UserId`) REFERENCES Servers(`ServerId`), + FOREIGN KEY (`ServerId`) REFERENCES Servers(`ServerId`), PRIMARY KEY(`UserId`) ); """) diff --git a/src/gismo_data/service/server_repository_service.py b/src/gismo_data/service/server_repository_service.py index 719fac8..cac5685 100644 --- a/src/gismo_data/service/server_repository_service.py +++ b/src/gismo_data/service/server_repository_service.py @@ -29,7 +29,7 @@ class ServerRepositoryService(ServerRepositoryABC): def get_server_by_id(self, id: int) -> Server: self._logger.trace(__name__, f'Send SQL command: {Server.get_select_by_id_string(id)}') - result = self._context.select(Server.get_select_by_id_string(id)) + result = self._context.select(Server.get_select_by_id_string(id))[0] return Server( result[1], id=result[0] @@ -49,15 +49,18 @@ class ServerRepositoryService(ServerRepositoryABC): if result is None or len(result) == 0: return None + result = result[0] + return Server( result[1], + result[2], + result[3], id=result[0] ) - def add_server(self, server: Server) -> int: + def add_server(self, server: Server): self._logger.trace(__name__, f'Send SQL command: {server.insert_string}') self._context.cursor.execute(server.insert_string) - return int(self._context.select("SELECT LAST_INSERT_ID();")[0]) def update_server(self, server: Server): self._logger.trace(__name__, f'Send SQL command: {server.udpate_string}') diff --git a/src/gismo_data/service/user_repository_service.py b/src/gismo_data/service/user_repository_service.py index 93b7460..6854313 100644 --- a/src/gismo_data/service/user_repository_service.py +++ b/src/gismo_data/service/user_repository_service.py @@ -23,6 +23,7 @@ class UserRepositoryService(UserRepositoryABC): self._logger.trace(__name__, f'Send SQL command: {User.get_select_all_string()}') results = self._context.select(User.get_select_all_string()) for result in results: + self._logger.trace(__name__, f'Get user with id {result[0]}') users.append(User( result[1], result[2], @@ -44,7 +45,7 @@ class UserRepositoryService(UserRepositoryABC): def get_user_by_discord_id(self, discord_id: int) -> User: self._logger.trace(__name__, f'Send SQL command: {User.get_select_by_discord_id_string(discord_id)}') - result = self._context.select(User.get_select_by_discord_id_string(discord_id)) + result = self._context.select(User.get_select_by_discord_id_string(discord_id))[0] return User( result[1], result[2], @@ -58,17 +59,20 @@ class UserRepositoryService(UserRepositoryABC): if result is None or len(result) == 0: return None + result = result[0] + return User( result[1], result[2], + result[3], + result[4], self._servers.get_server_by_id(result[3]), id=result[0] ) - def add_user(self, user: User) -> int: - self._logger.trace(__name__, f'Send SQL command: {user.insert_strin}') + def add_user(self, user: User): + self._logger.trace(__name__, f'Send SQL command: {user.insert_string}') self._context.cursor.execute(user.insert_string) - return int(self._context.select("SELECT LAST_INSERT_ID();")[0]) def update_user(self, user: User): self._logger.trace(__name__, f'Send SQL command: {user.udpate_string}') diff --git a/src/modules/database/__init__.py b/src/modules/database/__init__.py new file mode 100644 index 0000000..ad5eca3 --- /dev/null +++ b/src/modules/database/__init__.py @@ -0,0 +1 @@ +# imports: diff --git a/src/modules/database/database.json b/src/modules/database/database.json new file mode 100644 index 0000000..dfd40a3 --- /dev/null +++ b/src/modules/database/database.json @@ -0,0 +1,43 @@ +{ + "ProjectSettings": { + "Name": "modules/database", + "Version": { + "Major": "0", + "Minor": "0", + "Micro": "0" + }, + "Author": "", + "AuthorEmail": "", + "Description": "", + "LongDescription": "", + "URL": "", + "CopyrightDate": "", + "CopyrightName": "", + "LicenseName": "", + "LicenseDescription": "", + "Dependencies": [ + "sh_cpl-core>=2021.11.0.post1" + ], + "PythonVersion": ">=3.9.2", + "PythonPath": { + "linux": "" + }, + "Classifiers": [] + }, + "BuildSettings": { + "ProjectType": "library", + "SourcePath": "", + "OutputPath": "../../dist", + "Main": "modules/database.main", + "EntryPoint": "modules/database", + "IncludePackageData": false, + "Included": [], + "Excluded": [ + "*/__pycache__", + "*/logs", + "*/tests" + ], + "PackageData": {}, + "ProjectReferences": [] + } +} \ No newline at end of file diff --git a/src/modules/database/database.py b/src/modules/database/database.py new file mode 100644 index 0000000..7dd1569 --- /dev/null +++ b/src/modules/database/database.py @@ -0,0 +1,120 @@ +import asyncio +from datetime import datetime +import time + +import discord +from cpl_core.configuration import ConfigurationABC +from cpl_core.database.context import DatabaseContextABC +from cpl_core.logging import LoggerABC + +from gismo_core.abc.bot_service_abc import BotServiceABC +from gismo_core.abc.message_service_abc import MessageServiceABC +from gismo_core.configuration.server_settings import ServerSettings +from gismo_data.abc.user_repository_abc import UserRepositoryABC +from gismo_data.model.server import Server +from gismo_data.model.user import User +from gismo_data.service.user_repository_service import ServerRepositoryABC +from modules_core.abc.events.on_ready_abc import OnReadyABC +from modules_core.abc.module_abc import ModuleABC + + +class Database(ModuleABC, OnReadyABC): + + def __init__( + self, + config: ConfigurationABC, + logger: LoggerABC, + bot: BotServiceABC, + db_context: DatabaseContextABC, + server_repo: ServerRepositoryABC, + user_repo: UserRepositoryABC + ): + self._config = config + + self._logger = logger + self._bot = bot + self._db_context = db_context + self._servers = server_repo + self._users = user_repo + + ModuleABC.__init__(self) + self._priorities[OnReadyABC] = 0 + self._logger.trace(__name__, f'Module {type(self)} loaded') + + def _validate_init_time(self): + try: + start_time = self._config.get_configuration('Database_StartTime') + init_time = round((datetime.now() - start_time).total_seconds(), 2) + self._config.add_configuration('Database_InitTime', init_time) + self._logger.debug(__name__, f'Database Init time: {init_time}s') + # print warning if initialisation took too long + if init_time >= 30: + self._logger.warn( + __name__, 'It takes long time to start the bot!') + + # print error if initialisation took way too long + elif init_time >= 90: + self._logger.error( + __name__, 'It takes very long time to start the bot!!!') + except Exception as e:# + self._logger.error(__name__, 'Database init time calculation failed', e) + return + + def _check_servers(self): + for g in self._bot.guilds: + g: discord.Guild = g + try: + server = self._servers.find_server_by_discord_id(g.id) + if server is not None: + return + + self._logger.warn(__name__, f'Server not found in database: {g.id}') + self._logger.debug(__name__, f'Add server: {g.id}') + self._servers.add_server(Server(g.id)) + self._db_context.save_changes() + + self._logger.debug(__name__, f'Added server: {g.id}') + except Exception as e: + self._logger.error(__name__, f'Cannot get server', e) + + results = self._servers.get_servers() + if results is None or len(results) == 0: + self._logger.error(__name__, f'Table Servers is empty!') + + def _check_users(self): + for g in self._bot.guilds: + g: discord.Guild = g + + try: + server = self._servers.find_server_by_discord_id(g.id) + if server is None: + self._logger.fatal(__name__, f'Server not found in database: {g.id}') + break + + for u in g.members: + u: discord.Member = u + user = self._users.find_user_by_discord_id(u.id) + if user is not None: + break + + self._logger.warn(__name__, f'User not found in database: {u.id}') + self._logger.debug(__name__, f'Add user: {u.id}') + self._users.add_user(User(u.id, 0, server)) + self._db_context.save_changes() + + self._logger.debug(__name__, f'Added User: {u.id}') + except Exception as e: + self._logger.error(__name__, f'Cannot get User', e) + + results = self._users.get_users() + if results is None or len(results) == 0: + self._logger.error(__name__, f'Table Users is empty!') + + async def on_ready(self): + self._logger.debug(__name__, f'Module {type(self)} started') + + self._check_servers() + self._check_users() + + self._validate_init_time() + self._logger.trace(__name__, f'Module {type(self)} stopped') \ No newline at end of file diff --git a/src/modules/database/database_extension.py b/src/modules/database/database_extension.py new file mode 100644 index 0000000..2fee14a --- /dev/null +++ b/src/modules/database/database_extension.py @@ -0,0 +1,17 @@ +from datetime import datetime + +from cpl_core.application.application_extension_abc import ApplicationExtensionABC +from cpl_core.configuration import ConfigurationABC +from cpl_core.dependency_injection import ServiceProviderABC +from cpl_core.logging import LoggerABC + + +class DatabaseExtension(ApplicationExtensionABC): + + def __init__(self): + pass + + async def run(self, config: ConfigurationABC, services: ServiceProviderABC): + logger: LoggerABC = services.get_service(LoggerABC) + logger.debug(__name__, 'Database extension started') + config.add_configuration('Database_StartTime', datetime.now())