Compare commits

..

6 Commits

Author SHA1 Message Date
12b7c62b69 Fixed formatting
All checks were successful
Build on push / prepare (push) Successful in 9s
Build on push / core (push) Successful in 17s
Build on push / query (push) Successful in 17s
Build on push / dependency (push) Successful in 17s
Build on push / translation (push) Successful in 14s
Build on push / application (push) Successful in 18s
Build on push / database (push) Successful in 17s
Build on push / mail (push) Successful in 18s
Build on push / auth (push) Successful in 13s
Build on push / api (push) Successful in 17s
Test before pr merge / test-lint (pull_request) Successful in 5s
2025-09-21 23:48:09 +02:00
7fc70747bb Added black test
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
Build on push / prepare (push) Successful in 10s
Build on push / core (push) Successful in 17s
Build on push / query (push) Successful in 17s
Build on push / dependency (push) Successful in 17s
Build on push / api (push) Has been cancelled
Build on push / auth (push) Has been cancelled
Build on push / mail (push) Has started running
Build on push / translation (push) Has been cancelled
Build on push / application (push) Has been cancelled
Build on push / database (push) Has been cancelled
2025-09-21 23:47:15 +02:00
6de4f3c03a Middleware updated & Fixed mysql pool
All checks were successful
Build on push / prepare (push) Successful in 10s
Build on push / core (push) Successful in 19s
Build on push / query (push) Successful in 18s
Build on push / dependency (push) Successful in 17s
Build on push / database (push) Successful in 15s
Build on push / translation (push) Successful in 18s
Build on push / mail (push) Successful in 19s
Build on push / application (push) Successful in 21s
Build on push / auth (push) Successful in 14s
Build on push / api (push) Successful in 14s
2025-09-21 23:41:25 +02:00
ea3055527c Changed middleware to asgi 2025-09-21 21:22:19 +02:00
7b37748ca6 [WIP] validate token via keycloak 2025-09-21 21:07:09 +02:00
073b35f71a App deps check 2025-09-21 20:11:47 +02:00
31 changed files with 445 additions and 190 deletions

View File

@@ -0,0 +1,26 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Installing black
run: python3.12 -m pip install black
- name: Checking black
run: python3.12 -m black src --check

View File

@@ -0,0 +1,26 @@
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
def add_api(collection: _ServiceCollection):
try:
from cpl.database import mysql
collection.add_module(mysql)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-database", e)
try:
from cpl import auth
from cpl.auth import permission
collection.add_module(auth)
collection.add_module(permission)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-auth", e)
_ServiceCollection.with_module(add_api, __name__)

View File

View File

@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from starlette.types import Scope, Receive, Send
class ASGIMiddleware(ABC):
@abstractmethod
def __init__(self, app):
self._app = app
def _call_next(self, scope: Scope, receive: Receive, send: Send):
return self._app(scope, receive, send)
@abstractmethod
async def __call__(self, scope: Scope, receive: Receive, send: Send): ...

View File

@@ -1,11 +1,17 @@
from http.client import HTTPException
from starlette.responses import JSONResponse
from starlette.types import Scope, Receive, Send
class APIError(HTTPException):
status_code = 500
@classmethod
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send):
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
return await r(scope, receive, send)
@classmethod
def response(cls):
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)

View File

@@ -1,49 +1,76 @@
from keycloak import KeycloakAuthenticationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(BaseHTTPMiddleware):
class AuthenticationMiddleware(ASGIMiddleware):
@classmethod
async def _verify_login(cls, token: str) -> bool:
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
try:
user_info = keycloak.userinfo(token)
if not user_info:
return False
except KeycloakAuthenticationError:
return False
return True
@ServiceProviderABC.inject
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
async def dispatch(self, request: Request, call_next):
self._keycloak = keycloak
self._user_dao = user_dao
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes():
_logger.trace(f"No authentication required for {url}")
return await call_next(request)
return await self._app(scope, receive, send)
if not request.headers.get("Authorization"):
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return Unauthorized(f"Missing header Authorization").response()
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None)
if not auth_header or not auth_header.startswith("Bearer "):
return Unauthorized("Invalid Authorization header").response()
return await Unauthorized("Invalid Authorization header").asgi_response(scope, receive, send)
if not await self._verify_login(auth_header.split("Bearer ")[1]):
token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token):
_logger.debug(f"Unauthorized access to {url}, invalid token")
return Unauthorized("Invalid token").response()
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create
# unauthorized if user is deleted
return await call_next(request)
keycloak_id = self._keycloak.get_user_id(token)
if keycloak_id is None:
return await Unauthorized("Failed to get user id from token").asgi_response(scope, receive, send)
user = await self._get_or_crate_user(keycloak_id)
if user.deleted:
_logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
if existing is not None:
return existing
user = AuthUser(0, keycloak_id)
uid = await self._user_dao.create(user)
return await self._user_dao.get_by_id(uid)
async def _verify_login(self, token: str) -> bool:
try:
token_info = self._keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
_logger.error(f"Unexpected error during token verification: {e}")
return False

View File

@@ -1,21 +1,44 @@
import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
await self._log_request(request)
response = await call_next(request)
await self._log_after_request(request, response)
class LoggingMiddleware(ASGIMiddleware):
return response
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self._call_next(scope, receive, send)
return
request = get_request()
await self._log_request(request)
start_time = time.time()
response_body = b""
status_code = 500
async def send_wrapper(message):
nonlocal response_body, status_code
if message["type"] == "http.response.start":
status_code = message["status"]
if message["type"] == "http.response.body":
response_body += message.get("body", b"")
await send(message)
await self._call_next(scope, receive, send_wrapper)
duration = (time.time() - start_time) * 1000
await self._log_after_request(request, status_code, duration)
@staticmethod
def _filter_relevant_headers(headers: dict) -> dict:
@@ -33,7 +56,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
@classmethod
async def _log_request(cls, request: Request):
_logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
)
from cpl.core.ctx.user_context import get_user
@@ -55,11 +78,10 @@ class LoggingMiddleware(BaseHTTPMiddleware):
),
}
_logger.trace(f"Request {request.state.request_id}: {request_info}")
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
@staticmethod
async def _log_after_request(request: Request, response: Response):
duration = (time.time() - request.state.start_time) * 1000
async def _log_after_request(request: Request, status_code: int, duration: float):
_logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@@ -3,9 +3,11 @@ from contextvars import ContextVar
from typing import Optional, Union
from uuid import uuid4
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.typing import TRequest
@@ -14,34 +16,38 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
_logger = APILogger(__name__)
class RequestMiddleware(BaseHTTPMiddleware):
_request_token = {}
_user_token = {}
class RequestMiddleware(ASGIMiddleware):
@classmethod
async def set_request_data(cls, request: TRequest):
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = Request(scope, receive, send)
await self.set_request_data(request)
try:
await self._app(scope, receive, send)
finally:
await self.clean_request_data()
async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4()
request.state.start_time = time.time()
_logger.trace(f"Set new current request: {request.state.request_id}")
cls._request_token[request.state.request_id] = _request_context.set(request)
self._ctx_token = _request_context.set(request)
@classmethod
async def clean_request_data(cls):
async def clean_request_data(self):
request = get_request()
if request is None:
return
if request.state.request_id in cls._request_token:
_request_context.reset(cls._request_token[request.state.request_id])
if self._ctx_token is None:
return
async def dispatch(self, request: TRequest, call_next):
await self.set_request_data(request)
try:
response = await call_next(request)
return response
finally:
await self.clean_request_data()
_logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token)
def get_request() -> Optional[Union[TRequest, WebSocket]]:

View File

@@ -10,6 +10,7 @@ from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.types import ExceptionHandler
from cpl import api, auth
from cpl.api.api_logger import APILogger
from cpl.api.api_settings import ApiSettings
from cpl.api.error import APIError
@@ -27,7 +28,7 @@ _logger = APILogger("API")
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
super().__init__(services)
super().__init__(services, [auth, api])
self._app: Starlette | None = None
self._api_settings = Configuration.get(ApiSettings)
@@ -65,6 +66,10 @@ class WebApp(ApplicationABC):
_logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_database(self):
self.with_migrations()
self.with_seeders()
def with_app(self, app: Starlette):
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
@@ -131,7 +136,7 @@ class WebApp(ApplicationABC):
def with_authorization(self):
pass
def main(self):
async def main(self):
_logger.debug(f"Preparing API")
if self._app is None:
routes = [
@@ -161,10 +166,18 @@ class WebApp(ApplicationABC):
app = self._app
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
uvicorn.run(
app,
host=self._api_settings.host,
port=self._api_settings.port,
log_config=None,
# uvicorn.run(
# app,
# host=self._api_settings.host,
# port=self._api_settings.port,
# log_config=None,
# loop="asyncio"
# )
config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
)
server = uvicorn.Server(config)
await server.serve()
_logger.info("Shutdown API")

View File

@@ -22,8 +22,15 @@ class ApplicationABC(ABC):
"""
@abstractmethod
def __init__(self, services: ServiceProviderABC):
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
self._services = services
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
)
@property
def required_modules(self) -> list[str]:
return self._required_modules
@classmethod
def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
@@ -80,7 +87,7 @@ class ApplicationABC(ABC):
try:
Host.run(self.main)
except KeyboardInterrupt:
Console.close()
pass
@abstractmethod
def main(self): ...

View File

@@ -6,6 +6,7 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host
from cpl.core.errors import dependency_error
from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -35,6 +36,18 @@ class ApplicationBuilder(Generic[TApp]):
def service_provider(self):
return self._services.build()
def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules:
if module in self._services.loaded_modules:
continue
dependency_error(
module,
ImportError(
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
),
)
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
self._startup = startup
return self
@@ -62,4 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
for extension in self._app_extensions:
Host.run(extension.run, self.service_provider)
return self._app(self.service_provider)
app = self._app(self.service_provider)
self.validate_app_required_modules(app)
return app

View File

@@ -40,11 +40,10 @@ def _add_daos(collection: _ServiceCollection):
def add_auth(collection: _ServiceCollection):
import os
from cpl.core.console import Console
from cpl.database.service.migration_service import MigrationService
from cpl.database.model.server_type import ServerType, ServerTypes
try:
from cpl.database.service.migration_service import MigrationService
from cpl.database.model.server_type import ServerType, ServerTypes
collection.add_singleton(_KeycloakClient)
collection.add_singleton(_KeycloakAdmin)
@@ -59,22 +58,25 @@ def add_auth(collection: _ServiceCollection):
elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
except ImportError as e:
Console.error("cpl-auth is not installed", str(e))
from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e))
def add_permission(collection: _ServiceCollection):
from cpl.auth.permission_seeder import PermissionSeeder
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.auth.permission.permissions import Permissions
from .permission_seeder import PermissionSeeder
from .permission.permissions_registry import PermissionsRegistry
from .permission.permissions import Permissions
try:
from cpl.database.abc.data_seeder_abc import DataSeederABC
collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions)
except ImportError as e:
from cpl.core.console import Console
Console.error("cpl-auth is not installed", str(e))
Console.error("cpl-database is not installed", str(e))
_ServiceCollection.with_module(add_auth, __name__)

View File

@@ -1,4 +1,6 @@
from keycloak import KeycloakOpenID, KeycloakAdmin, KeycloakOpenIDConnection
from typing import Optional
from keycloak import KeycloakOpenID
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings
@@ -17,10 +19,7 @@ class KeycloakClient(KeycloakOpenID):
client_secret_key=settings.client_secret,
)
_logger.info("Initializing Keycloak client")
connection = KeycloakOpenIDConnection(
server_url=settings.url,
client_id=settings.client_id,
realm_name=settings.realm,
client_secret_key=settings.client_secret,
)
self._admin = KeycloakAdmin(connection=connection)
def get_user_id(self, token: str) -> Optional[str]:
info = self.introspect(token)
return info.get("sub", None)

View File

@@ -16,7 +16,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, aliases=["keycloakId"])
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
async def get_users():
return [(x.id, x.username, x.email) for x in await self.get_all()]

View File

@@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
CREATE TABLE IF NOT EXISTS administration_auth_users_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
keycloakId CHAR(36) NOT NULL,
-- for history
deleted BOOL NOT NULL,

View File

@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
CREATE TABLE IF NOT EXISTS administration_api_keys_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
identifier VARCHAR(255) NOT NULL,
keyString VARCHAR(255) NOT NULL,
deleted BOOL NOT NULL,

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
CREATE TABLE IF NOT EXISTS permission_permissions_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT NULL,
deleted BOOL NOT NULL,
@@ -57,7 +57,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
CREATE TABLE IF NOT EXISTS permission_roles_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT NULL,
deleted BOOL NOT NULL,
@@ -103,7 +103,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
RoleId INT NOT NULL,
permissionId INT NOT NULL,
deleted BOOL NOT NULL,
@@ -149,7 +149,7 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
RoleId INT NOT NULL,
UserId INT NOT NULL,
deleted BOOL NOT NULL,

View File

@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
apiKeyId INT NOT NULL,
permissionId INT NOT NULL,
deleted BOOL NOT NULL,

View File

@@ -130,7 +130,7 @@ class Configuration:
key_name = key.__name__ if inspect.isclass(key) else key
result = cls._config.get(key_name, default)
if issubclass(key, ConfigurationModelABC) and result == default:
if isclass(key) and issubclass(key, ConfigurationModelABC) and result == default:
result = key()
cls.set(key, result)

View File

@@ -68,7 +68,7 @@ class ConfigurationModelABC(ABC):
value = cast(Environment.get(env_field, str), cast_type)
if value is None and required:
raise ValueError(f"{field} is required")
raise ValueError(f"{type(self).__name__}.{field} is required")
elif value is None:
self._options[field] = default
return

View File

@@ -0,0 +1,15 @@
import traceback
from cpl.core.console import Console
def dependency_error(package_name: str, e: ImportError) -> None:
Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
tb = traceback.format_exc()
if not tb.startswith("NoneType: None"):
Console.write_line("->", tb)
elif e is not None:
Console.write_line("->", str(e))
exit(1)

View File

@@ -1,3 +1,4 @@
import os
from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC
@@ -7,13 +8,19 @@ from . import postgres as _postgres
from .table_manager import TableManager
def _with_migrations(self: _ApplicationABC, *paths: list[str]) -> _ApplicationABC:
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService)
migration_service.with_directory("./scripts")
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
if isinstance(paths, str):
paths = [paths]
for path in paths:
migration_service.with_directory(path)
Host.run(migration_service.migrate)
return self

View File

@@ -156,13 +156,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
:param dict result: Result from the database
:return:
"""
value_map: dict[str, T] = {}
value_map: dict[str, Any] = {}
db_names = self.__db_names.items()
for db_name, value in result.items():
# Find the attribute name corresponding to the db_name
attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
if attr_name:
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
attr_name = next((k for k, v in db_names if v == db_name), None)
if not attr_name:
continue
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
return self._model_type(**value_map)

View File

@@ -1,7 +1,7 @@
from typing import Optional, Any
import sqlparse
import aiomysql
from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment
from cpl.database.db_logger import DBLogger
@@ -11,95 +11,82 @@ _logger = DBLogger(__name__)
class MySQLPool:
"""
Create a pool when connecting to MySQL, which will decrease the time spent in
requesting connection, creating connection, and closing connection.
"""
def __init__(self, database_settings: DatabaseSettings):
self._db_settings = database_settings
self.pool: Optional[aiomysql.Pool] = None
self._dbconfig = {
"host": database_settings.host,
"port": database_settings.port,
"user": database_settings.user,
"password": database_settings.password,
"database": database_settings.database,
"ssl_disabled": True,
}
self._pool: Optional[MySQLConnectionPool] = None
async def _get_pool(self):
if self.pool is None or self.pool._closed:
if self._pool is None:
self._pool = MySQLConnectionPool(
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
)
await self._pool.initialize_pool()
con = await self._pool.get_connection()
try:
self.pool = await aiomysql.create_pool(
host=self._db_settings.host,
port=self._db_settings.port,
user=self._db_settings.user,
password=self._db_settings.password,
db=self._db_settings.database,
minsize=1,
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
)
async with await con.cursor() as cursor:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except Exception as e:
_logger.fatal("Failed to connect to the database", e)
raise
return self.pool
_logger.fatal(f"Error connecting to the database: {e}")
finally:
await con.close()
return self._pool
@staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
result = []
if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries:
if q.strip() == "":
continue
await cursor.execute(q, args)
if cursor.description is not None:
result = await cursor.fetchall()
else:
await cursor.execute(query, args)
if cursor.description is not None:
result = await cursor.fetchall()
return result
async def execute(self, query: str, args=None, multi=True) -> list[list]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
con = await pool.get_connection()
try:
async with await con.cursor() as cursor:
result = await self._exec_sql(cursor, query, args, multi)
await con.commit()
if cursor.description is not None: # Query returns rows
res = await cursor.fetchall()
if res is None:
return []
return [list(row) for row in res]
else:
return []
return result
finally:
await con.close()
async def select(self, query: str, args=None, multi=True) -> list[str]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
con = await pool.get_connection()
try:
async with await con.cursor() as cursor:
res = await self._exec_sql(cursor, query, args, multi)
return list(res)
finally:
await con.close()
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor(aiomysql.DictCursor) as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
con = await pool.get_connection()
try:
async with await con.cursor(dictionary=True) as cursor:
res = await self._exec_sql(cursor, query, args, multi)
return list(res)
finally:
await con.close()

View File

@@ -25,21 +25,23 @@ class PostgresPool:
f"password={database_settings.password} "
f"dbname={database_settings.database}"
)
self.pool: Optional[AsyncConnectionPool] = None
self._pool: Optional[AsyncConnectionPool] = None
async def _get_pool(self):
pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
)
await pool.open()
try:
async with pool.connection() as con:
await pool.check_connection(con)
except PoolTimeout as e:
await pool.close()
_logger.fatal(f"Failed to connect to the database", e)
return pool
if self._pool is None:
pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
)
await pool.open()
try:
async with pool.connection() as con:
await pool.check_connection(con)
except PoolTimeout as e:
await pool.close()
_logger.fatal(f"Failed to connect to the database", e)
self._pool = pool
return self._pool
@staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):

View File

@@ -1,4 +1,4 @@
from typing import Union, Type, Callable
from typing import Union, Type, Callable, Self
from cpl.core.log.logger import Logger
from cpl.core.log.logger_abc import LoggerABC
@@ -15,12 +15,17 @@ class ServiceCollection:
_modules: dict[str, Callable] = {}
@classmethod
def with_module(cls, func: Callable, name: str = None):
def with_module(cls, func: Callable, name: str = None) -> type[Self]:
cls._modules[func.__name__ if name is None else name] = func
return cls
def __init__(self):
self._service_descriptors: list[ServiceDescriptor] = []
self._loaded_modules: set[str] = set()
@property
def loaded_modules(self) -> set[str]:
return self._loaded_modules
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
found = False
@@ -45,15 +50,15 @@ class ServiceCollection:
return self
def add_singleton(self, service_type: T, service: Service = None):
def add_singleton(self, service_type: T, service: Service = None) -> Self:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service)
return self
def add_scoped(self, service_type: T, service: Service = None):
def add_scoped(self, service_type: T, service: Service = None) -> Self:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service)
return self
def add_transient(self, service_type: T, service: Service = None):
def add_transient(self, service_type: T, service: Service = None) -> Self:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self
@@ -62,7 +67,7 @@ class ServiceCollection:
ServiceProviderABC.set_global_provider(sp)
return sp
def add_module(self, module: str | object):
def add_module(self, module: str | object) -> Self:
if not isinstance(module, str):
module = module.__name__
@@ -70,7 +75,10 @@ class ServiceCollection:
raise ValueError(f"Module {module} not found")
self._modules[module](self)
if module not in self._loaded_modules:
self._loaded_modules.add(module)
return self
def add_logging(self):
def add_logging(self) -> Self:
self.add_transient(LoggerABC, Logger)
return self

View File

@@ -24,19 +24,19 @@ class ServiceProviderABC(ABC):
return cls._provider
@classmethod
def get_global_service(cls, instance_type: T, *args, **kwargs) -> Optional[R]:
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
if cls._provider is None:
return None
return cls._provider.get_service(instance_type, *args, **kwargs)
@classmethod
def get_global_services(cls, instance_type: T, *args, **kwargs) -> list[Optional[R]]:
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
if cls._provider is None:
return []
return cls._provider.get_services(instance_type, *args, **kwargs)
@abstractmethod
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]: ...
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
@abstractmethod
def _build_service(self, service_type: type, *args, **kwargs) -> object:
@@ -114,14 +114,24 @@ class ServiceProviderABC(ABC):
if f is None:
return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
async def inner(*args, **kwargs):
def inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
if iscoroutinefunction(f):
return await f(*args, *injection, **kwargs)
return f(*args, *injection, **kwargs)
return inner

View File

@@ -0,0 +1,8 @@
{
"Logging": {
"Path": "logs/",
"Filename": "log_$start_time.log",
"ConsoleLevel": "TRACE",
"Level": "TRACE"
}
}

View File

@@ -0,0 +1,26 @@
{
"TimeFormat": {
"DateFormat": "%Y-%m-%d",
"TimeFormat": "%H:%M:%S",
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
},
"Log": {
"Path": "logs/",
"Filename": "log_$start_time.log",
"ConsoleLevel": "TRACE",
"Level": "TRACE"
},
"Database": {
"Host": "localhost",
"User": "cpl",
"Port": 3306,
"Password": "cpl",
"Database": "cpl",
"Charset": "utf8mb4",
"UseUnicode": "true",
"Buffered": "true"
}
}

View File

@@ -0,0 +1,15 @@
{
"TimeFormat": {
"DateFormat": "%Y-%m-%d",
"TimeFormat": "%H:%M:%S",
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
},
"Log": {
"Path": "logs/",
"Filename": "log_$start_time.log",
"ConsoleLevel": "ERROR",
"Level": "WARNING"
}
}

View File

@@ -1,21 +1,31 @@
from starlette.responses import JSONResponse
from cpl import api
from cpl.api.web_app import WebApp
from cpl.application import ApplicationBuilder
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
builder.services.add_logging()
builder.services.add_transient(PingService)
builder.services.add_module(api)
app = builder.build()
app.with_logging()
app.with_database()
app.with_authentication()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
app.with_routes_directory("routes")
app.with_logging()
app.with_authentication()
app.run()