Middleware updated & Fixed mysql pool
All checks were successful
Build on push / prepare (push) Successful in 10s
Build on push / core (push) Successful in 19s
Build on push / query (push) Successful in 18s
Build on push / dependency (push) Successful in 17s
Build on push / database (push) Successful in 15s
Build on push / translation (push) Successful in 18s
Build on push / mail (push) Successful in 19s
Build on push / application (push) Successful in 21s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s
All checks were successful
Build on push / prepare (push) Successful in 10s
Build on push / core (push) Successful in 19s
Build on push / query (push) Successful in 18s
Build on push / dependency (push) Successful in 17s
Build on push / database (push) Successful in 15s
Build on push / translation (push) Successful in 18s
Build on push / mail (push) Successful in 19s
Build on push / application (push) Successful in 21s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s
This commit is contained in:
0
src/cpl-api/cpl/api/abc/__init__.py
Normal file
0
src/cpl-api/cpl/api/abc/__init__.py
Normal file
15
src/cpl-api/cpl/api/abc/asgi_middleware_abc.py
Normal file
15
src/cpl-api/cpl/api/abc/asgi_middleware_abc.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
|
||||||
|
class ASGIMiddleware(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, app):
|
||||||
|
self._app = app
|
||||||
|
|
||||||
|
def _call_next(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
return self._app(scope, receive, send)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send): ...
|
||||||
@@ -1,11 +1,17 @@
|
|||||||
from http.client import HTTPException
|
from http.client import HTTPException
|
||||||
|
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
|
||||||
class APIError(HTTPException):
|
class APIError(HTTPException):
|
||||||
status_code = 500
|
status_code = 500
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send):
|
||||||
|
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
||||||
|
return await r(scope, receive, send)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def response(cls):
|
def response(cls):
|
||||||
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
||||||
|
|||||||
@@ -1,21 +1,28 @@
|
|||||||
from keycloak import KeycloakAuthenticationError
|
from keycloak import KeycloakAuthenticationError
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.api_logger import APILogger
|
||||||
from cpl.api.error import Unauthorized
|
from cpl.api.error import Unauthorized
|
||||||
from cpl.api.middleware.request import get_request
|
from cpl.api.middleware.request import get_request
|
||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.auth.keycloak import KeycloakClient
|
from cpl.auth.keycloak import KeycloakClient
|
||||||
|
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware:
|
class AuthenticationMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
def __init__(self, app):
|
@ServiceProviderABC.inject
|
||||||
self._app = app
|
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||||
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
self._keycloak = keycloak
|
||||||
|
self._user_dao = user_dao
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
request = get_request()
|
request = get_request()
|
||||||
url = request.url.path
|
url = request.url.path
|
||||||
|
|
||||||
@@ -25,27 +32,41 @@ class AuthenticationMiddleware:
|
|||||||
|
|
||||||
if not request.headers.get("Authorization"):
|
if not request.headers.get("Authorization"):
|
||||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||||
return Unauthorized(f"Missing header Authorization").response()
|
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization", None)
|
auth_header = request.headers.get("Authorization", None)
|
||||||
if not auth_header or not auth_header.startswith("Bearer "):
|
if not auth_header or not auth_header.startswith("Bearer "):
|
||||||
return Unauthorized("Invalid Authorization header").response()
|
return await Unauthorized("Invalid Authorization header").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
token = auth_header.split("Bearer ")[1]
|
token = auth_header.split("Bearer ")[1]
|
||||||
if not await self._verify_login(token):
|
if not await self._verify_login(token):
|
||||||
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||||
return Unauthorized("Invalid token").response()
|
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
# check user exists in db, if not create
|
# check user exists in db, if not create
|
||||||
# unauthorized if user is deleted
|
keycloak_id = self._keycloak.get_user_id(token)
|
||||||
return await self._app(scope, receive, send)
|
if keycloak_id is None:
|
||||||
|
return await Unauthorized("Failed to get user id from token").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
@classmethod
|
user = await self._get_or_crate_user(keycloak_id)
|
||||||
async def _verify_login(cls, token: str) -> bool:
|
if user.deleted:
|
||||||
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
|
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||||
|
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
return await self._call_next(scope, receive, send)
|
||||||
|
|
||||||
|
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
|
||||||
|
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||||
|
if existing is not None:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
user = AuthUser(0, keycloak_id)
|
||||||
|
uid = await self._user_dao.create(user)
|
||||||
|
return await self._user_dao.get_by_id(uid)
|
||||||
|
|
||||||
|
async def _verify_login(self, token: str) -> bool:
|
||||||
try:
|
try:
|
||||||
token_info = keycloak.introspect(token)
|
token_info = self._keycloak.introspect(token)
|
||||||
return token_info.get("active", False)
|
return token_info.get("active", False)
|
||||||
except KeycloakAuthenticationError as e:
|
except KeycloakAuthenticationError as e:
|
||||||
_logger.debug(f"Keycloak authentication error: {e}")
|
_logger.debug(f"Keycloak authentication error: {e}")
|
||||||
|
|||||||
@@ -1,20 +1,22 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.api_logger import APILogger
|
||||||
from cpl.api.middleware.request import get_request
|
from cpl.api.middleware.request import get_request
|
||||||
|
|
||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
class LoggingMiddleware:
|
class LoggingMiddleware(ASGIMiddleware):
|
||||||
def __init__(self, app: ASGIApp):
|
|
||||||
self.app = app
|
def __init__(self, app):
|
||||||
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
if scope["type"] != "http":
|
if scope["type"] != "http":
|
||||||
await self.app(scope, receive, send)
|
await self._call_next(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
request = get_request()
|
request = get_request()
|
||||||
@@ -32,7 +34,7 @@ class LoggingMiddleware:
|
|||||||
response_body += message.get("body", b"")
|
response_body += message.get("body", b"")
|
||||||
await send(message)
|
await send(message)
|
||||||
|
|
||||||
await self.app(scope, receive, send_wrapper)
|
await self._call_next(scope, receive, send_wrapper)
|
||||||
|
|
||||||
duration = (time.time() - start_time) * 1000
|
duration = (time.time() - start_time) * 1000
|
||||||
await self._log_after_request(request, status_code, duration)
|
await self._log_after_request(request, status_code, duration)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from typing import Optional, Union
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.api_logger import APILogger
|
||||||
from cpl.api.typing import TRequest
|
from cpl.api.typing import TRequest
|
||||||
|
|
||||||
@@ -14,12 +16,13 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
|||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RequestMiddleware:
|
class RequestMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self._app = app
|
ASGIMiddleware.__init__(self, app)
|
||||||
self._ctx_token = None
|
self._ctx_token = None
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
request = Request(scope, receive, send)
|
request = Request(scope, receive, send)
|
||||||
await self.set_request_data(request)
|
await self.set_request_data(request)
|
||||||
|
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ class WebApp(ApplicationABC):
|
|||||||
_logger.debug(f"Allowed origins: {origins}")
|
_logger.debug(f"Allowed origins: {origins}")
|
||||||
return origins.split(",")
|
return origins.split(",")
|
||||||
|
|
||||||
|
def with_database(self):
|
||||||
|
self.with_migrations()
|
||||||
|
self.with_seeders()
|
||||||
|
|
||||||
def with_app(self, app: Starlette):
|
def with_app(self, app: Starlette):
|
||||||
assert app is not None, "app must not be None"
|
assert app is not None, "app must not be None"
|
||||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||||
@@ -132,7 +136,7 @@ class WebApp(ApplicationABC):
|
|||||||
def with_authorization(self):
|
def with_authorization(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def main(self):
|
async def main(self):
|
||||||
_logger.debug(f"Preparing API")
|
_logger.debug(f"Preparing API")
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
routes = [
|
routes = [
|
||||||
@@ -162,10 +166,22 @@ class WebApp(ApplicationABC):
|
|||||||
app = self._app
|
app = self._app
|
||||||
|
|
||||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||||
uvicorn.run(
|
# uvicorn.run(
|
||||||
|
# app,
|
||||||
|
# host=self._api_settings.host,
|
||||||
|
# port=self._api_settings.port,
|
||||||
|
# log_config=None,
|
||||||
|
# loop="asyncio"
|
||||||
|
# )
|
||||||
|
|
||||||
|
config = uvicorn.Config(
|
||||||
app,
|
app,
|
||||||
host=self._api_settings.host,
|
host=self._api_settings.host,
|
||||||
port=self._api_settings.port,
|
port=self._api_settings.port,
|
||||||
log_config=None,
|
log_config=None,
|
||||||
|
loop="asyncio"
|
||||||
)
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
_logger.info("Shutdown API")
|
_logger.info("Shutdown API")
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class ApplicationABC(ABC):
|
|||||||
try:
|
try:
|
||||||
Host.run(self.main)
|
Host.run(self.main)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
Console.close()
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def main(self): ...
|
def main(self): ...
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from keycloak import KeycloakOpenID
|
from keycloak import KeycloakOpenID
|
||||||
|
|
||||||
from cpl.auth.auth_logger import AuthLogger
|
from cpl.auth.auth_logger import AuthLogger
|
||||||
@@ -17,3 +19,7 @@ class KeycloakClient(KeycloakOpenID):
|
|||||||
client_secret_key=settings.client_secret,
|
client_secret_key=settings.client_secret,
|
||||||
)
|
)
|
||||||
_logger.info("Initializing Keycloak client")
|
_logger.info("Initializing Keycloak client")
|
||||||
|
|
||||||
|
def get_user_id(self, token: str) -> Optional[str]:
|
||||||
|
info = self.introspect(token)
|
||||||
|
return info.get("sub", None)
|
||||||
@@ -16,7 +16,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
||||||
|
|
||||||
self.attribute(AuthUser.keycloak_id, str, aliases=["keycloakId"])
|
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
|
||||||
|
|
||||||
async def get_users():
|
async def get_users():
|
||||||
return [(x.id, x.username, x.email) for x in await self.get_all()]
|
return [(x.id, x.username, x.email) for x in await self.get_all()]
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS administration_auth_users_history
|
CREATE TABLE IF NOT EXISTS administration_auth_users_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
keycloakId CHAR(36) NOT NULL,
|
keycloakId CHAR(36) NOT NULL,
|
||||||
-- for history
|
-- for history
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS administration_api_keys_history
|
CREATE TABLE IF NOT EXISTS administration_api_keys_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
identifier VARCHAR(255) NOT NULL,
|
identifier VARCHAR(255) NOT NULL,
|
||||||
keyString VARCHAR(255) NOT NULL,
|
keyString VARCHAR(255) NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_permissions_history
|
CREATE TABLE IF NOT EXISTS permission_permissions_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
name VARCHAR(255) NOT NULL,
|
name VARCHAR(255) NOT NULL,
|
||||||
description TEXT NULL,
|
description TEXT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
@@ -57,7 +57,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_roles_history
|
CREATE TABLE IF NOT EXISTS permission_roles_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
name VARCHAR(255) NOT NULL,
|
name VARCHAR(255) NOT NULL,
|
||||||
description TEXT NULL,
|
description TEXT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
@@ -103,7 +103,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
|
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
RoleId INT NOT NULL,
|
RoleId INT NOT NULL,
|
||||||
permissionId INT NOT NULL,
|
permissionId INT NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
@@ -149,7 +149,7 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
|
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
RoleId INT NOT NULL,
|
RoleId INT NOT NULL,
|
||||||
UserId INT NOT NULL,
|
UserId INT NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
|
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
apiKeyId INT NOT NULL,
|
apiKeyId INT NOT NULL,
|
||||||
permissionId INT NOT NULL,
|
permissionId INT NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
||||||
@@ -7,13 +8,19 @@ from . import postgres as _postgres
|
|||||||
from .table_manager import TableManager
|
from .table_manager import TableManager
|
||||||
|
|
||||||
|
|
||||||
def _with_migrations(self: _ApplicationABC, *paths: list[str]) -> _ApplicationABC:
|
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
||||||
from cpl.application.host import Host
|
from cpl.application.host import Host
|
||||||
|
|
||||||
from cpl.database.service.migration_service import MigrationService
|
from cpl.database.service.migration_service import MigrationService
|
||||||
|
|
||||||
migration_service = self._services.get_service(MigrationService)
|
migration_service = self._services.get_service(MigrationService)
|
||||||
migration_service.with_directory("./scripts")
|
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
|
||||||
|
|
||||||
|
if isinstance(paths, str):
|
||||||
|
paths = [paths]
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
migration_service.with_directory(path)
|
||||||
|
|
||||||
Host.run(migration_service.migrate)
|
Host.run(migration_service.migrate)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -156,13 +156,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
|||||||
:param dict result: Result from the database
|
:param dict result: Result from the database
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
value_map: dict[str, T] = {}
|
value_map: dict[str, Any] = {}
|
||||||
|
db_names = self.__db_names.items()
|
||||||
|
|
||||||
for db_name, value in result.items():
|
for db_name, value in result.items():
|
||||||
# Find the attribute name corresponding to the db_name
|
# Find the attribute name corresponding to the db_name
|
||||||
attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
|
attr_name = next((k for k, v in db_names if v == db_name), None)
|
||||||
if attr_name:
|
if not attr_name:
|
||||||
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
|
continue
|
||||||
|
|
||||||
|
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
|
||||||
|
|
||||||
return self._model_type(**value_map)
|
return self._model_type(**value_map)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
import aiomysql
|
from mysql.connector.aio import MySQLConnectionPool
|
||||||
|
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.db_logger import DBLogger
|
||||||
@@ -9,97 +9,83 @@ from cpl.database.model import DatabaseSettings
|
|||||||
|
|
||||||
_logger = DBLogger(__name__)
|
_logger = DBLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MySQLPool:
|
class MySQLPool:
|
||||||
"""
|
|
||||||
Create a pool when connecting to MySQL, which will decrease the time spent in
|
|
||||||
requesting connection, creating connection, and closing connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, database_settings: DatabaseSettings):
|
def __init__(self, database_settings: DatabaseSettings):
|
||||||
self._db_settings = database_settings
|
self._dbconfig = {
|
||||||
self.pool: Optional[aiomysql.Pool] = None
|
"host": database_settings.host,
|
||||||
|
"port": database_settings.port,
|
||||||
|
"user": database_settings.user,
|
||||||
|
"password": database_settings.password,
|
||||||
|
"database": database_settings.database,
|
||||||
|
"ssl_disabled": True,
|
||||||
|
}
|
||||||
|
self._pool: Optional[MySQLConnectionPool] = None
|
||||||
|
|
||||||
async def _get_pool(self):
|
async def _get_pool(self):
|
||||||
if self.pool is None or self.pool._closed:
|
if self._pool is None:
|
||||||
|
self._pool = MySQLConnectionPool(
|
||||||
|
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
||||||
|
)
|
||||||
|
await self._pool.initialize_pool()
|
||||||
|
|
||||||
|
con = await self._pool.get_connection()
|
||||||
try:
|
try:
|
||||||
self.pool = await aiomysql.create_pool(
|
async with await con.cursor() as cursor:
|
||||||
host=self._db_settings.host,
|
await cursor.execute("SELECT 1")
|
||||||
port=self._db_settings.port,
|
await cursor.fetchall()
|
||||||
user=self._db_settings.user,
|
|
||||||
password=self._db_settings.password,
|
|
||||||
db=self._db_settings.database,
|
|
||||||
minsize=1,
|
|
||||||
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.fatal("Failed to connect to the database", e)
|
_logger.fatal(f"Error connecting to the database: {e}")
|
||||||
raise
|
finally:
|
||||||
return self.pool
|
await con.close()
|
||||||
|
|
||||||
|
return self._pool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||||
|
result = []
|
||||||
if multi:
|
if multi:
|
||||||
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
||||||
for q in queries:
|
for q in queries:
|
||||||
if q.strip() == "":
|
if q.strip() == "":
|
||||||
continue
|
continue
|
||||||
await cursor.execute(q, args)
|
await cursor.execute(q, args)
|
||||||
|
if cursor.description is not None:
|
||||||
|
result = await cursor.fetchall()
|
||||||
else:
|
else:
|
||||||
await cursor.execute(query, args)
|
await cursor.execute(query, args)
|
||||||
|
if cursor.description is not None:
|
||||||
|
result = await cursor.fetchall()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
||||||
"""
|
|
||||||
Execute a SQL statement, it could be with args and without args. The usage is
|
|
||||||
similar to the execute() function in aiomysql.
|
|
||||||
:param query: SQL clause
|
|
||||||
:param args: args needed by the SQL clause
|
|
||||||
:param multi: if the query is a multi-statement
|
|
||||||
:return: return result
|
|
||||||
"""
|
|
||||||
pool = await self._get_pool()
|
pool = await self._get_pool()
|
||||||
async with pool.acquire() as con:
|
con = await pool.get_connection()
|
||||||
async with con.cursor() as cursor:
|
try:
|
||||||
await self._exec_sql(cursor, query, args, multi)
|
async with await con.cursor() as cursor:
|
||||||
|
result = await self._exec_sql(cursor, query, args, multi)
|
||||||
await con.commit()
|
await con.commit()
|
||||||
|
return result
|
||||||
if cursor.description is not None: # Query returns rows
|
finally:
|
||||||
res = await cursor.fetchall()
|
await con.close()
|
||||||
if res is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [list(row) for row in res]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
||||||
"""
|
|
||||||
Execute a SQL statement, it could be with args and without args. The usage is
|
|
||||||
similar to the execute() function in aiomysql.
|
|
||||||
:param query: SQL clause
|
|
||||||
:param args: args needed by the SQL clause
|
|
||||||
:param multi: if the query is a multi-statement
|
|
||||||
:return: return result
|
|
||||||
"""
|
|
||||||
pool = await self._get_pool()
|
pool = await self._get_pool()
|
||||||
async with pool.acquire() as con:
|
con = await pool.get_connection()
|
||||||
async with con.cursor() as cursor:
|
try:
|
||||||
await self._exec_sql(cursor, query, args, multi)
|
async with await con.cursor() as cursor:
|
||||||
res = await cursor.fetchall()
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
return list(res)
|
return list(res)
|
||||||
|
finally:
|
||||||
|
await con.close()
|
||||||
|
|
||||||
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
||||||
"""
|
|
||||||
Execute a SQL statement, it could be with args and without args. The usage is
|
|
||||||
similar to the execute() function in aiomysql.
|
|
||||||
:param query: SQL clause
|
|
||||||
:param args: args needed by the SQL clause
|
|
||||||
:param multi: if the query is a multi-statement
|
|
||||||
:return: return result
|
|
||||||
"""
|
|
||||||
pool = await self._get_pool()
|
pool = await self._get_pool()
|
||||||
async with pool.acquire() as con:
|
con = await pool.get_connection()
|
||||||
async with con.cursor(aiomysql.DictCursor) as cursor:
|
try:
|
||||||
await self._exec_sql(cursor, query, args, multi)
|
async with await con.cursor(dictionary=True) as cursor:
|
||||||
res = await cursor.fetchall()
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
return list(res)
|
return list(res)
|
||||||
|
finally:
|
||||||
|
await con.close()
|
||||||
|
|||||||
@@ -25,21 +25,23 @@ class PostgresPool:
|
|||||||
f"password={database_settings.password} "
|
f"password={database_settings.password} "
|
||||||
f"dbname={database_settings.database}"
|
f"dbname={database_settings.database}"
|
||||||
)
|
)
|
||||||
|
self._pool: Optional[AsyncConnectionPool] = None
|
||||||
self.pool: Optional[AsyncConnectionPool] = None
|
|
||||||
|
|
||||||
async def _get_pool(self):
|
async def _get_pool(self):
|
||||||
pool = AsyncConnectionPool(
|
if self._pool is None:
|
||||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
pool = AsyncConnectionPool(
|
||||||
)
|
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||||
await pool.open()
|
)
|
||||||
try:
|
await pool.open()
|
||||||
async with pool.connection() as con:
|
try:
|
||||||
await pool.check_connection(con)
|
async with pool.connection() as con:
|
||||||
except PoolTimeout as e:
|
await pool.check_connection(con)
|
||||||
await pool.close()
|
except PoolTimeout as e:
|
||||||
_logger.fatal(f"Failed to connect to the database", e)
|
await pool.close()
|
||||||
return pool
|
_logger.fatal(f"Failed to connect to the database", e)
|
||||||
|
self._pool = pool
|
||||||
|
|
||||||
|
return self._pool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||||
|
|||||||
@@ -114,14 +114,22 @@ class ServiceProviderABC(ABC):
|
|||||||
if f is None:
|
if f is None:
|
||||||
return functools.partial(cls.inject)
|
return functools.partial(cls.inject)
|
||||||
|
|
||||||
|
if iscoroutinefunction(f):
|
||||||
|
@functools.wraps(f)
|
||||||
|
async def async_inner(*args, **kwargs):
|
||||||
|
if cls._provider is None:
|
||||||
|
raise Exception(f"{cls.__name__} not build!")
|
||||||
|
|
||||||
|
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||||
|
return await f(*args, *injection, **kwargs)
|
||||||
|
|
||||||
|
return async_inner
|
||||||
|
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
async def inner(*args, **kwargs):
|
def inner(*args, **kwargs):
|
||||||
if cls._provider is None:
|
if cls._provider is None:
|
||||||
raise Exception(f"{cls.__name__} not build!")
|
raise Exception(f"{cls.__name__} not build!")
|
||||||
|
|
||||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||||
if iscoroutinefunction(f):
|
|
||||||
return await f(*args, *injection, **kwargs)
|
|
||||||
return f(*args, *injection, **kwargs)
|
return f(*args, *injection, **kwargs)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from service import PingService
|
|||||||
def main():
|
def main():
|
||||||
builder = ApplicationBuilder[WebApp](WebApp)
|
builder = ApplicationBuilder[WebApp](WebApp)
|
||||||
|
|
||||||
|
|
||||||
Configuration.add_json_file(f"appsettings.json")
|
Configuration.add_json_file(f"appsettings.json")
|
||||||
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
||||||
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
||||||
@@ -21,10 +20,12 @@ def main():
|
|||||||
builder.services.add_module(api)
|
builder.services.add_module(api)
|
||||||
|
|
||||||
app = builder.build()
|
app = builder.build()
|
||||||
|
app.with_logging()
|
||||||
|
app.with_database()
|
||||||
|
|
||||||
|
app.with_authentication()
|
||||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
||||||
app.with_routes_directory("routes")
|
app.with_routes_directory("routes")
|
||||||
app.with_logging()
|
|
||||||
app.with_authentication()
|
|
||||||
|
|
||||||
app.run()
|
app.run()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user