Use db pools
This commit is contained in:
parent
dd4b9182f3
commit
995c498e54
50
bot/src/bot_data/db_connection.py
Normal file
50
bot/src/bot_data/db_connection.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
|
||||
from cpl_core.database import DatabaseSettings
|
||||
from cpl_core.database.connection import DatabaseConnectionABC
|
||||
from mysql.connector.abstracts import MySQLConnectionAbstract
|
||||
from mysql.connector.cursor import MySQLCursorBuffered
|
||||
|
||||
|
||||
class DBConnection(DatabaseConnectionABC):
|
||||
def __init__(self):
|
||||
DatabaseConnectionABC.__init__(self)
|
||||
|
||||
self._database: Optional[MySQLConnectionAbstract] = None
|
||||
self._cursor: Optional[MySQLCursorBuffered] = None
|
||||
|
||||
@property
|
||||
def server(self) -> MySQLConnectionAbstract:
|
||||
return self._database
|
||||
|
||||
@property
|
||||
def cursor(self) -> MySQLCursorBuffered:
|
||||
return self._cursor
|
||||
|
||||
def connect(self, settings: DatabaseSettings):
|
||||
# connection = sql.connect(
|
||||
# host=settings.host,
|
||||
# port=settings.port,
|
||||
# user=settings.user,
|
||||
# passwd=CredentialManager.decrypt(settings.password),
|
||||
# charset=settings.charset,
|
||||
# use_unicode=settings.use_unicode,
|
||||
# buffered=settings.buffered,
|
||||
# auth_plugin=settings.auth_plugin,
|
||||
# ssl_disabled=settings.ssl_disabled,
|
||||
# )
|
||||
# connection.cursor().execute(f"CREATE DATABASE IF NOT EXISTS `{settings.database}`;")
|
||||
# self._database = sql.connect(
|
||||
# host=settings.host,
|
||||
# port=settings.port,
|
||||
# user=settings.user,
|
||||
# passwd=CredentialManager.decrypt(settings.password),
|
||||
# db=settings.database,
|
||||
# charset=settings.charset,
|
||||
# use_unicode=settings.use_unicode,
|
||||
# buffered=settings.buffered,
|
||||
# auth_plugin=settings.auth_plugin,
|
||||
# ssl_disabled=settings.ssl_disabled,
|
||||
# )
|
||||
self._
|
||||
self._cursor = self._database.cursor()
|
@ -1,4 +1,3 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from cpl_core.database import DatabaseSettings
|
||||
@ -7,6 +6,7 @@ from cpl_core.database.context import DatabaseContext
|
||||
from bot_core.exception.service_error_code_enum import ServiceErrorCode
|
||||
from bot_core.exception.service_exception import ServiceException
|
||||
from bot_core.logging.database_logger import DatabaseLogger
|
||||
from bot_data.mysql_pool import MySQLPool
|
||||
|
||||
|
||||
class DBContext(DatabaseContext):
|
||||
@ -14,27 +14,28 @@ class DBContext(DatabaseContext):
|
||||
self._logger = logger
|
||||
|
||||
DatabaseContext.__init__(self)
|
||||
self._pool: MySQLPool = None
|
||||
self._fails = 0
|
||||
|
||||
def connect(self, database_settings: DatabaseSettings):
|
||||
try:
|
||||
self._logger.debug(__name__, "Connecting to database")
|
||||
self._db.connect(database_settings)
|
||||
self._pool = MySQLPool(database_settings)
|
||||
self._pool.execute(f"CREATE DATABASE IF NOT EXISTS `{database_settings.database}`;", commit=True)
|
||||
self._logger.info(__name__, "Connected to database")
|
||||
except Exception as e:
|
||||
self._logger.fatal(__name__, "Connecting to database failed", e)
|
||||
|
||||
@property
|
||||
def cursor(self):
|
||||
return self
|
||||
|
||||
def save_changes(self):
|
||||
try:
|
||||
self._logger.trace(__name__, "Save changes")
|
||||
super(DBContext, self).save_changes()
|
||||
self._logger.debug(__name__, "Saved changes")
|
||||
except Exception as e:
|
||||
self._logger.error(__name__, "Saving changes failed", e)
|
||||
pass
|
||||
|
||||
def select(self, statement: str) -> list[tuple]:
|
||||
try:
|
||||
return super(DBContext, self).select(statement)
|
||||
return self._pool.execute(statement)
|
||||
except Exception as e:
|
||||
if self._fails >= 3:
|
||||
self._logger.error(__name__, f"Database error caused by {statement}", e)
|
||||
@ -47,9 +48,11 @@ class DBContext(DatabaseContext):
|
||||
self._logger.error(__name__, f"Database error caused by {statement}", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
time.sleep(0.5)
|
||||
self._logger.debug(__name__, "Retry select")
|
||||
return self.select(statement)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
|
||||
def execute(self, statement: str):
|
||||
return self._pool.execute(statement, commit=True)
|
||||
|
104
bot/src/bot_data/mysql_pool.py
Normal file
104
bot/src/bot_data/mysql_pool.py
Normal file
@ -0,0 +1,104 @@
|
||||
# https://stackoverflow.com/questions/32658679/how-to-create-a-mysql-connection-pool-or-any-better-way-to-initialize-the-multip
|
||||
import mysql.connector as sql
|
||||
from cpl_core.database import DatabaseSettings
|
||||
from cpl_core.utils import CredentialManager
|
||||
|
||||
|
||||
class MySQLPool(object):
|
||||
"""
|
||||
create a pool when connect mysql, which will decrease the time spent in
|
||||
request connection, create connection and close connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database_settings: DatabaseSettings,
|
||||
pool_size=5,
|
||||
):
|
||||
res = {
|
||||
"host": database_settings.host,
|
||||
"port": database_settings.port,
|
||||
"user": database_settings.user,
|
||||
"password": CredentialManager.decrypt(database_settings.password),
|
||||
"database": database_settings.database,
|
||||
}
|
||||
|
||||
self.dbconfig = res
|
||||
self.pool = self.create_pool(pool_name="MySqlPool", pool_size=pool_size)
|
||||
|
||||
def create_pool(self, pool_name="MySqlPool", pool_size=3):
|
||||
"""
|
||||
Create a connection pool, after created, the request of connecting
|
||||
MySQL could get a connection from this pool instead of request to
|
||||
create a connection.
|
||||
:param pool_name: the name of pool, default is "mypool"
|
||||
:param pool_size: the size of pool, default is 3
|
||||
:return: connection pool
|
||||
"""
|
||||
pool = sql.pooling.MySQLConnectionPool(
|
||||
pool_name=pool_name, pool_size=pool_size, pool_reset_session=True, **self.dbconfig
|
||||
)
|
||||
return pool
|
||||
|
||||
def close(self, conn, cursor):
|
||||
"""
|
||||
A method used to close connection of mysql.
|
||||
:param conn:
|
||||
:param cursor:
|
||||
:return:
|
||||
"""
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
def execute(self, sql, args=None, commit=False):
|
||||
"""
|
||||
Execute a sql, it could be with args and with out args. The usage is
|
||||
similar with execute() function in module pymysql.
|
||||
:param sql: sql clause
|
||||
:param args: args need by sql clause
|
||||
:param commit: whether to commit
|
||||
:return: if commit, return None, else, return result
|
||||
"""
|
||||
# get connection form connection pool instead of create one.
|
||||
conn = self.pool.get_connection()
|
||||
cursor = conn.cursor()
|
||||
if args:
|
||||
cursor.execute(sql, args)
|
||||
else:
|
||||
cursor.execute(sql)
|
||||
if commit is True:
|
||||
conn.commit()
|
||||
self.close(conn, cursor)
|
||||
return None
|
||||
else:
|
||||
res = cursor.fetchall()
|
||||
self.close(conn, cursor)
|
||||
return res
|
||||
|
||||
def executemany(self, sql, args, commit=False):
|
||||
"""
|
||||
Execute with many args. Similar with executemany() function in pymysql.
|
||||
args should be a sequence.
|
||||
:param sql: sql clause
|
||||
:param args: args
|
||||
:param commit: commit or not.
|
||||
:return: if commit, return None, else, return result
|
||||
"""
|
||||
# get connection form connection pool instead of create one.
|
||||
conn = self.pool.get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.executemany(sql, args)
|
||||
if commit is True:
|
||||
conn.commit()
|
||||
self.close(conn, cursor)
|
||||
return None
|
||||
else:
|
||||
res = cursor.fetchall()
|
||||
self.close(conn, cursor)
|
||||
return res
|
||||
|
||||
def commit(self):
|
||||
conn = self.pool.get_connection()
|
||||
conn.commit()
|
||||
cursor = conn.cursor()
|
||||
self.close(conn, cursor)
|
@ -1,13 +1,13 @@
|
||||
import glob
|
||||
import os
|
||||
|
||||
from cpl_core.database.context import DatabaseContextABC
|
||||
from cpl_core.dependency_injection import ServiceProviderABC
|
||||
from cpl_query.extension import List
|
||||
from packaging import version
|
||||
|
||||
import bot
|
||||
from bot_core.logging.database_logger import DatabaseLogger
|
||||
from bot_data.db_context import DBContext
|
||||
from bot_data.model.migration import Migration
|
||||
from bot_data.model.migration_history import MigrationHistory
|
||||
|
||||
@ -17,13 +17,12 @@ class MigrationService:
|
||||
self,
|
||||
logger: DatabaseLogger,
|
||||
services: ServiceProviderABC,
|
||||
db: DatabaseContextABC,
|
||||
db: DBContext,
|
||||
):
|
||||
self._logger = logger
|
||||
self._services = services
|
||||
|
||||
self._db = db
|
||||
self._cursor = db.cursor
|
||||
|
||||
def _get_migration_history(self) -> List[MigrationHistory]:
|
||||
results = self._db.select(
|
||||
@ -42,7 +41,7 @@ class MigrationService:
|
||||
return
|
||||
|
||||
self._logger.debug(__name__, f"Migrate new migration {migration.migration_id} to old method")
|
||||
self._cursor.execute(migration.change_id_string(f"{migration.migration_id}Migration"))
|
||||
self._db.execute(migration.change_id_string(f"{migration.migration_id}Migration"))
|
||||
self._db.save_changes()
|
||||
|
||||
def _migration_migrations_to_new(self, migration: MigrationHistory):
|
||||
@ -50,12 +49,11 @@ class MigrationService:
|
||||
return
|
||||
|
||||
self._logger.debug(__name__, f"Migrate old migration {migration.migration_id} to new method")
|
||||
self._cursor.execute(migration.change_id_string(migration.migration_id.replace("Migration", "")))
|
||||
self._db.execute(migration.change_id_string(migration.migration_id.replace("Migration", "")))
|
||||
self._db.save_changes()
|
||||
|
||||
def _migrate_from_old_to_new(self):
|
||||
self._cursor.execute("SHOW TABLES LIKE 'MigrationHistory'")
|
||||
result = self._cursor.fetchone()
|
||||
result = self._db.select("SHOW TABLES LIKE 'MigrationHistory'")
|
||||
if not result:
|
||||
return
|
||||
|
||||
@ -120,8 +118,7 @@ class MigrationService:
|
||||
active_statement = ""
|
||||
try:
|
||||
# check if table exists
|
||||
self._cursor.execute("SHOW TABLES LIKE 'MigrationHistory'")
|
||||
result = self._cursor.fetchone()
|
||||
result = self._db.select("SHOW TABLES LIKE 'MigrationHistory'")
|
||||
if result:
|
||||
# there is a table named "tableName"
|
||||
self._logger.trace(
|
||||
@ -142,9 +139,9 @@ class MigrationService:
|
||||
if statement in ["", "\n"]:
|
||||
continue
|
||||
active_statement = statement
|
||||
self._cursor.execute(statement + ";")
|
||||
self._db.execute(statement + ";")
|
||||
|
||||
self._cursor.execute(
|
||||
self._db.execute(
|
||||
MigrationHistory(migration.name).insert_string
|
||||
if upgrade
|
||||
else MigrationHistory(migration.name).delete_string
|
||||
|
@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from cpl_core.database.context import DatabaseContextABC
|
||||
from cpl_core.time import TimeFormatSettings
|
||||
from cpl_query.extension import List
|
||||
|
||||
from bot_core.logging.database_logger import DatabaseLogger
|
||||
@ -15,12 +16,14 @@ from bot_data.model.user_message_count_per_hour import UserMessageCountPerHour
|
||||
class UserMessageCountPerHourRepositoryService(UserMessageCountPerHourRepositoryABC):
|
||||
def __init__(
|
||||
self,
|
||||
time_format: TimeFormatSettings,
|
||||
logger: DatabaseLogger,
|
||||
db_context: DatabaseContextABC,
|
||||
users: UserRepositoryABC,
|
||||
):
|
||||
UserMessageCountPerHourRepositoryABC.__init__(self)
|
||||
|
||||
self._time_format = time_format
|
||||
self._logger = logger
|
||||
self._context = db_context
|
||||
self._users = users
|
||||
@ -67,7 +70,12 @@ class UserMessageCountPerHourRepositoryService(UserMessageCountPerHourRepository
|
||||
) -> UserMessageCountPerHour:
|
||||
sql = UserMessageCountPerHour.get_select_by_user_id_and_date_string(user_id, date)
|
||||
self._logger.trace(__name__, f"Send SQL command: {sql}")
|
||||
return self._from_result(self._context.select(sql)[0])
|
||||
res = self._context.select(sql)
|
||||
if len(res) > 0:
|
||||
return self._from_result(res[0])
|
||||
|
||||
user = self._users.get_user_by_id(user_id)
|
||||
return UserMessageCountPerHour(date.strftime(self._time_format.date_time_format), date.hour, 0, user)
|
||||
|
||||
def find_user_message_count_per_hour_by_user_id_and_date(
|
||||
self, user_id: int, date: datetime
|
||||
|
Loading…
Reference in New Issue
Block a user