Compare commits
7 Commits
2025.09.19
...
2025.09.21
| Author | SHA1 | Date | |
|---|---|---|---|
| 7fc70747bb | |||
| 6de4f3c03a | |||
| ea3055527c | |||
| 7b37748ca6 | |||
| 073b35f71a | |||
| eceff6128b | |||
| 17dfb245bf |
26
.gitea/workflows/test_before_merge.yaml
Normal file
26
.gitea/workflows/test_before_merge.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
name: Test before pr merge
|
||||
run-name: Test before pr merge
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- reopened
|
||||
- synchronize
|
||||
- ready_for_review
|
||||
|
||||
jobs:
|
||||
test-lint:
|
||||
runs-on: [ runner ]
|
||||
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
|
||||
steps:
|
||||
- name: Clone Repository
|
||||
uses: https://github.com/actions/checkout@v3
|
||||
with:
|
||||
token: ${{ secrets.CI_ACCESS_TOKEN }}
|
||||
|
||||
- name: Installing black
|
||||
run: python3.12 -m pip install black
|
||||
|
||||
- name: Checking black
|
||||
run: python3.12 -m black src --check
|
||||
61
install.sh
Normal file
61
install.sh
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Find and combine requirements from src/cpl-*/requirements.txt,
|
||||
# filtering out lines whose *package name* starts with "cpl-".
|
||||
# Works with pinned versions, extras, markers, editable installs, and VCS refs.
|
||||
|
||||
shopt -s nullglob
|
||||
|
||||
req_files=(src/cpl-*/requirements.txt)
|
||||
if ((${#req_files[@]} == 0)); then
|
||||
echo "No requirements files found at src/cpl-*/requirements.txt" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
tmp_combined="$(mktemp)"
|
||||
trap 'rm -f "$tmp_combined"' EXIT
|
||||
|
||||
# Concatenate, trim comments/whitespace, filter out cpl-* packages, dedupe.
|
||||
# We keep non-package options/flags/constraints as-is.
|
||||
awk '
|
||||
function trim(s){ sub(/^[[:space:]]+/,"",s); sub(/[[:space:]]+$/,"",s); return s }
|
||||
|
||||
{
|
||||
line=$0
|
||||
# drop full-line comments and strip inline comments
|
||||
if (line ~ /^[[:space:]]*#/) next
|
||||
sub(/#[^!].*$/,"",line) # strip trailing comment (simple heuristic)
|
||||
line=trim(line)
|
||||
if (line == "") next
|
||||
|
||||
# Determine the package *name* even for "-e", extras, pins, markers, or VCS "@"
|
||||
e = line
|
||||
sub(/^-e[[:space:]]+/,"",e) # remove editable prefix
|
||||
# Tokenize up to the first of these separators: space, [ < > = ! ~ ; @
|
||||
token = e
|
||||
sub(/\[.*/,"",token) # remove extras quickly
|
||||
n = split(token, a, /[<>=!~;@[:space:]]/)
|
||||
name = tolower(a[1])
|
||||
|
||||
# If the first token (name) starts with "cpl-", skip this requirement
|
||||
if (name ~ /^cpl-/) next
|
||||
|
||||
print line
|
||||
}
|
||||
' "${req_files[@]}" | sort -u > "$tmp_combined"
|
||||
|
||||
if ! [ -s "$tmp_combined" ]; then
|
||||
echo "Nothing to install after filtering out cpl-* packages." >&2
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Installing dependencies (excluding cpl-*) from:"
|
||||
printf ' - %s\n' "${req_files[@]}"
|
||||
echo
|
||||
echo "Final set to install:"
|
||||
cat "$tmp_combined"
|
||||
echo
|
||||
|
||||
# Use python -m pip for reliability; change to python3 if needed.
|
||||
python -m pip install -r "$tmp_combined"
|
||||
@@ -0,0 +1,20 @@
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
|
||||
def add_api(collection: _ServiceCollection):
|
||||
try:
|
||||
from cpl.database import mysql
|
||||
collection.add_module(mysql)
|
||||
except ImportError as e:
|
||||
from cpl.core.errors import dependency_error
|
||||
dependency_error("cpl-database", e)
|
||||
|
||||
try:
|
||||
from cpl import auth
|
||||
from cpl.auth import permission
|
||||
collection.add_module(auth)
|
||||
collection.add_module(permission)
|
||||
except ImportError as e:
|
||||
from cpl.core.errors import dependency_error
|
||||
dependency_error("cpl-auth", e)
|
||||
|
||||
_ServiceCollection.with_module(add_api, __name__)
|
||||
0
src/cpl-api/cpl/api/abc/__init__.py
Normal file
0
src/cpl-api/cpl/api/abc/__init__.py
Normal file
15
src/cpl-api/cpl/api/abc/asgi_middleware_abc.py
Normal file
15
src/cpl-api/cpl/api/abc/asgi_middleware_abc.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from starlette.types import Scope, Receive, Send
|
||||
|
||||
|
||||
class ASGIMiddleware(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
|
||||
def _call_next(self, scope: Scope, receive: Receive, send: Send):
|
||||
return self._app(scope, receive, send)
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send): ...
|
||||
@@ -1,9 +1,21 @@
|
||||
from http.client import HTTPException
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import Scope, Receive, Send
|
||||
|
||||
|
||||
class APIError(HTTPException):
|
||||
status_code = 500
|
||||
|
||||
@classmethod
|
||||
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send):
|
||||
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
||||
return await r(scope, receive, send)
|
||||
|
||||
@classmethod
|
||||
def response(cls):
|
||||
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
||||
|
||||
|
||||
class Unauthorized(APIError):
|
||||
status_code = 401
|
||||
|
||||
76
src/cpl-api/cpl/api/middleware/authentication.py
Normal file
76
src/cpl-api/cpl/api/middleware/authentication.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from keycloak import KeycloakAuthenticationError
|
||||
from starlette.types import Scope, Receive, Send
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._keycloak = keycloak
|
||||
self._user_dao = user_dao
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = get_request()
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
_logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
|
||||
|
||||
auth_header = request.headers.get("Authorization", None)
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await Unauthorized("Invalid Authorization header").asgi_response(scope, receive, send)
|
||||
|
||||
token = auth_header.split("Bearer ")[1]
|
||||
if not await self._verify_login(token):
|
||||
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
|
||||
|
||||
# check user exists in db, if not create
|
||||
keycloak_id = self._keycloak.get_user_id(token)
|
||||
if keycloak_id is None:
|
||||
return await Unauthorized("Failed to get user id from token").asgi_response(scope, receive, send)
|
||||
|
||||
user = await self._get_or_crate_user(keycloak_id)
|
||||
if user.deleted:
|
||||
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
|
||||
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
|
||||
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
user = AuthUser(0, keycloak_id)
|
||||
uid = await self._user_dao.create(user)
|
||||
return await self._user_dao.get_by_id(uid)
|
||||
|
||||
async def _verify_login(self, token: str) -> bool:
|
||||
try:
|
||||
token_info = self._keycloak.introspect(token)
|
||||
return token_info.get("active", False)
|
||||
except KeycloakAuthenticationError as e:
|
||||
_logger.debug(f"Keycloak authentication error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
_logger.error(f"Unexpected error during token verification: {e}")
|
||||
return False
|
||||
@@ -1,21 +1,43 @@
|
||||
import time
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
class LoggingMiddleware(ASGIMiddleware):
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
def __init__(self, app):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
await self._call_next(scope, receive, send)
|
||||
return
|
||||
|
||||
request = get_request()
|
||||
await self._log_request(request)
|
||||
response = await call_next(request)
|
||||
await self._log_after_request(request, response)
|
||||
start_time = time.time()
|
||||
|
||||
return response
|
||||
response_body = b""
|
||||
status_code = 500
|
||||
|
||||
async def send_wrapper(message):
|
||||
nonlocal response_body, status_code
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
if message["type"] == "http.response.body":
|
||||
response_body += message.get("body", b"")
|
||||
await send(message)
|
||||
|
||||
await self._call_next(scope, receive, send_wrapper)
|
||||
|
||||
duration = (time.time() - start_time) * 1000
|
||||
await self._log_after_request(request, status_code, duration)
|
||||
|
||||
@staticmethod
|
||||
def _filter_relevant_headers(headers: dict) -> dict:
|
||||
@@ -33,7 +55,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
@classmethod
|
||||
async def _log_request(cls, request: Request):
|
||||
_logger.debug(
|
||||
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
)
|
||||
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
@@ -55,11 +77,10 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
),
|
||||
}
|
||||
|
||||
_logger.trace(f"Request {request.state.request_id}: {request_info}")
|
||||
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
|
||||
@staticmethod
|
||||
async def _log_after_request(request: Request, response: Response):
|
||||
duration = (time.time() - request.state.start_time) * 1000
|
||||
async def _log_after_request(request: Request, status_code: int, duration: float):
|
||||
_logger.info(
|
||||
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
@@ -3,9 +3,11 @@ from contextvars import ContextVar
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Scope, Receive, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
|
||||
@@ -14,35 +16,39 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class RequestMiddleware(BaseHTTPMiddleware):
|
||||
_request_token = {}
|
||||
_user_token = {}
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
@classmethod
|
||||
async def set_request_data(cls, request: TRequest):
|
||||
def __init__(self, app):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
self._ctx_token = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = Request(scope, receive, send)
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
await self._app(scope, receive, send)
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
|
||||
async def set_request_data(self, request: TRequest):
|
||||
request.state.request_id = uuid4()
|
||||
request.state.start_time = time.time()
|
||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
|
||||
cls._request_token[request.state.request_id] = _request_context.set(request)
|
||||
self._ctx_token = _request_context.set(request)
|
||||
|
||||
@classmethod
|
||||
async def clean_request_data(cls):
|
||||
async def clean_request_data(self):
|
||||
request = get_request()
|
||||
if request is None:
|
||||
return
|
||||
|
||||
if request.state.request_id in cls._request_token:
|
||||
_request_context.reset(cls._request_token[request.state.request_id])
|
||||
if self._ctx_token is None:
|
||||
return
|
||||
|
||||
async def dispatch(self, request: TRequest, call_next):
|
||||
await self.set_request_data(request)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
_logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_request_context.reset(self._ctx_token)
|
||||
|
||||
|
||||
def get_request() -> Optional[Union[TRequest, WebSocket]]:
|
||||
return _request_context.get()
|
||||
return _request_context.get()
|
||||
@@ -3,11 +3,35 @@ from starlette.routing import Route
|
||||
|
||||
class Router:
|
||||
_registered_routes: list[Route] = []
|
||||
_auth_required: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def get_routes(cls) -> list[Route]:
|
||||
return cls._registered_routes
|
||||
|
||||
@classmethod
|
||||
def get_auth_required_routes(cls) -> list[str]:
|
||||
return cls._auth_required
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls):
|
||||
"""
|
||||
Decorator to mark a route as requiring authentication.
|
||||
Usage:
|
||||
@Route.authenticate()
|
||||
@Route.get("/example")
|
||||
async def example_endpoint(request: TRequest):
|
||||
...
|
||||
"""
|
||||
|
||||
def inner(fn):
|
||||
route_path = getattr(fn, "_route_path", None)
|
||||
if route_path and route_path not in cls._auth_required:
|
||||
cls._auth_required.append(route_path)
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path=None, **kwargs):
|
||||
def inner(fn):
|
||||
@@ -57,4 +81,4 @@ class Router:
|
||||
|
||||
return fn
|
||||
|
||||
return inner
|
||||
return inner
|
||||
|
||||
@@ -10,4 +10,4 @@ HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
||||
PartialMiddleware = Union[
|
||||
Middleware,
|
||||
Callable[[ASGIApp], ASGIApp],
|
||||
]
|
||||
]
|
||||
|
||||
@@ -10,9 +10,11 @@ from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.types import ExceptionHandler
|
||||
|
||||
from cpl import api, auth
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.api_settings import ApiSettings
|
||||
from cpl.api.error import APIError
|
||||
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||
from cpl.api.middleware.logging import LoggingMiddleware
|
||||
from cpl.api.middleware.request import RequestMiddleware
|
||||
from cpl.api.router import Router
|
||||
@@ -24,10 +26,9 @@ from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
_logger = APILogger("API")
|
||||
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
super().__init__(services)
|
||||
super().__init__(services, [auth, api])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
self._api_settings = Configuration.get(ApiSettings)
|
||||
@@ -37,18 +38,22 @@ class WebApp(ApplicationABC):
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
]
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {Exception: self.handle_exception}
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||
Exception: self._handle_exception,
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_exception(request: Request, exc: Exception):
|
||||
async def _handle_exception(request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
_logger.error(exc)
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
if hasattr(request.state, "request_id"):
|
||||
_logger.error(f"Request {request.state.request_id}", exc)
|
||||
else:
|
||||
_logger.error("Request unknown", exc)
|
||||
|
||||
if isinstance(exc, APIError):
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
return JSONResponse({"error": str(exc)}, status_code=500)
|
||||
|
||||
def _get_allowed_origins(self):
|
||||
@@ -61,6 +66,10 @@ class WebApp(ApplicationABC):
|
||||
_logger.debug(f"Allowed origins: {origins}")
|
||||
return origins.split(",")
|
||||
|
||||
def with_database(self):
|
||||
self.with_migrations()
|
||||
self.with_seeders()
|
||||
|
||||
def with_app(self, app: Starlette):
|
||||
assert app is not None, "app must not be None"
|
||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||
@@ -96,7 +105,15 @@ class WebApp(ApplicationABC):
|
||||
self._check_for_app()
|
||||
assert path is not None, "path must not be None"
|
||||
assert fn is not None, "fn must not be None"
|
||||
assert method in ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], "method must be a valid HTTP method"
|
||||
assert method in [
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"PATCH",
|
||||
"OPTIONS",
|
||||
"HEAD",
|
||||
], "method must be a valid HTTP method"
|
||||
self._routes.append(Route(path, fn, methods=[method], **kwargs))
|
||||
return self
|
||||
|
||||
@@ -105,16 +122,21 @@ class WebApp(ApplicationABC):
|
||||
|
||||
if isinstance(middleware, Middleware):
|
||||
self._middleware.append(middleware)
|
||||
|
||||
elif callable(middleware):
|
||||
self._middleware.append(Middleware(middleware))
|
||||
else:
|
||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||
|
||||
|
||||
return self
|
||||
|
||||
def main(self):
|
||||
def with_authentication(self):
|
||||
self.with_middleware(AuthenticationMiddleware)
|
||||
return self
|
||||
|
||||
def with_authorization(self):
|
||||
pass
|
||||
|
||||
async def main(self):
|
||||
_logger.debug(f"Preparing API")
|
||||
if self._app is None:
|
||||
routes = [
|
||||
@@ -144,10 +166,22 @@ class WebApp(ApplicationABC):
|
||||
app = self._app
|
||||
|
||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
uvicorn.run(
|
||||
# uvicorn.run(
|
||||
# app,
|
||||
# host=self._api_settings.host,
|
||||
# port=self._api_settings.port,
|
||||
# log_config=None,
|
||||
# loop="asyncio"
|
||||
# )
|
||||
|
||||
config = uvicorn.Config(
|
||||
app,
|
||||
host=self._api_settings.host,
|
||||
port=self._api_settings.port,
|
||||
log_config=None,
|
||||
loop="asyncio"
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
_logger.info("Shutdown API")
|
||||
|
||||
@@ -3,4 +3,5 @@ cpl-application
|
||||
cpl-core
|
||||
cpl-dependency
|
||||
starlette==0.48.0
|
||||
python-multipart==0.0.20
|
||||
python-multipart==0.0.20
|
||||
uvicorn==0.35.0
|
||||
@@ -22,8 +22,16 @@ class ApplicationABC(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
|
||||
self._services = services
|
||||
self._required_modules = [
|
||||
x.__name__ if not isinstance(x, str) else x
|
||||
for x in required_modules
|
||||
] if required_modules else []
|
||||
|
||||
@property
|
||||
def required_modules(self) -> list[str]:
|
||||
return self._required_modules
|
||||
|
||||
@classmethod
|
||||
def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
|
||||
@@ -80,7 +88,7 @@ class ApplicationABC(ABC):
|
||||
try:
|
||||
Host.run(self.main)
|
||||
except KeyboardInterrupt:
|
||||
Console.close()
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def main(self): ...
|
||||
|
||||
@@ -6,6 +6,7 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
|
||||
from cpl.application.abc.startup_abc import StartupABC
|
||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||
@@ -35,6 +36,18 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
def service_provider(self):
|
||||
return self._services.build()
|
||||
|
||||
def validate_app_required_modules(self, app: ApplicationABC):
|
||||
for module in app.required_modules:
|
||||
if module in self._services.loaded_modules:
|
||||
continue
|
||||
|
||||
dependency_error(
|
||||
module,
|
||||
ImportError(
|
||||
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
|
||||
),
|
||||
)
|
||||
|
||||
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
|
||||
self._startup = startup
|
||||
return self
|
||||
@@ -62,4 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
for extension in self._app_extensions:
|
||||
Host.run(extension.run, self.service_provider)
|
||||
|
||||
return self._app(self.service_provider)
|
||||
app = self._app(self.service_provider)
|
||||
self.validate_app_required_modules(app)
|
||||
return app
|
||||
|
||||
@@ -40,11 +40,10 @@ def _add_daos(collection: _ServiceCollection):
|
||||
def add_auth(collection: _ServiceCollection):
|
||||
import os
|
||||
|
||||
from cpl.core.console import Console
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
|
||||
try:
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
|
||||
collection.add_singleton(_KeycloakClient)
|
||||
collection.add_singleton(_KeycloakAdmin)
|
||||
|
||||
@@ -59,22 +58,23 @@ def add_auth(collection: _ServiceCollection):
|
||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
|
||||
except ImportError as e:
|
||||
Console.error("cpl-auth is not installed", str(e))
|
||||
from cpl.core.console import Console
|
||||
Console.error("cpl-database is not installed", str(e))
|
||||
|
||||
|
||||
def add_permission(collection: _ServiceCollection):
|
||||
from cpl.auth.permission_seeder import PermissionSeeder
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from .permission_seeder import PermissionSeeder
|
||||
from .permission.permissions_registry import PermissionsRegistry
|
||||
from .permission.permissions import Permissions
|
||||
|
||||
try:
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
collection.add_singleton(DataSeederABC, PermissionSeeder)
|
||||
PermissionsRegistry.with_enum(Permissions)
|
||||
except ImportError as e:
|
||||
from cpl.core.console import Console
|
||||
|
||||
Console.error("cpl-auth is not installed", str(e))
|
||||
Console.error("cpl-database is not installed", str(e))
|
||||
|
||||
|
||||
_ServiceCollection.with_module(add_auth, __name__)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from keycloak import KeycloakOpenID, KeycloakAdmin, KeycloakOpenIDConnection
|
||||
from typing import Optional
|
||||
|
||||
from keycloak import KeycloakOpenID
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
@@ -17,10 +19,7 @@ class KeycloakClient(KeycloakOpenID):
|
||||
client_secret_key=settings.client_secret,
|
||||
)
|
||||
_logger.info("Initializing Keycloak client")
|
||||
connection = KeycloakOpenIDConnection(
|
||||
server_url=settings.url,
|
||||
client_id=settings.client_id,
|
||||
realm_name=settings.realm,
|
||||
client_secret_key=settings.client_secret,
|
||||
)
|
||||
self._admin = KeycloakAdmin(connection=connection)
|
||||
|
||||
def get_user_id(self, token: str) -> Optional[str]:
|
||||
info = self.introspect(token)
|
||||
return info.get("sub", None)
|
||||
@@ -16,7 +16,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
||||
|
||||
self.attribute(AuthUser.keycloak_id, str, aliases=["keycloakId"])
|
||||
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
|
||||
|
||||
async def get_users():
|
||||
return [(x.id, x.username, x.email) for x in await self.get_all()]
|
||||
|
||||
@@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
|
||||
|
||||
CREATE TABLE IF NOT EXISTS administration_auth_users_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
keycloakId CHAR(36) NOT NULL,
|
||||
-- for history
|
||||
deleted BOOL NOT NULL,
|
||||
|
||||
@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
|
||||
|
||||
CREATE TABLE IF NOT EXISTS administration_api_keys_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
identifier VARCHAR(255) NOT NULL,
|
||||
keyString VARCHAR(255) NOT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
|
||||
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_permissions_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
@@ -57,7 +57,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_roles_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
@@ -103,7 +103,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
RoleId INT NOT NULL,
|
||||
permissionId INT NOT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
@@ -149,7 +149,7 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
RoleId INT NOT NULL,
|
||||
UserId INT NOT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
|
||||
@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
id INT NOT NULL,
|
||||
apiKeyId INT NOT NULL,
|
||||
permissionId INT NOT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
cpl-core
|
||||
cpl-dependency
|
||||
cpl-database
|
||||
python-keycloak-5.8.1
|
||||
python-keycloak==5.8.1
|
||||
@@ -130,7 +130,7 @@ class Configuration:
|
||||
key_name = key.__name__ if inspect.isclass(key) else key
|
||||
|
||||
result = cls._config.get(key_name, default)
|
||||
if issubclass(key, ConfigurationModelABC) and result == default:
|
||||
if isclass(key) and issubclass(key, ConfigurationModelABC) and result == default:
|
||||
result = key()
|
||||
cls.set(key, result)
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class ConfigurationModelABC(ABC):
|
||||
value = cast(Environment.get(env_field, str), cast_type)
|
||||
|
||||
if value is None and required:
|
||||
raise ValueError(f"{field} is required")
|
||||
raise ValueError(f"{type(self).__name__}.{field} is required")
|
||||
elif value is None:
|
||||
self._options[field] = default
|
||||
return
|
||||
|
||||
15
src/cpl-core/cpl/core/errors.py
Normal file
15
src/cpl-core/cpl/core/errors.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import traceback
|
||||
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
def dependency_error(package_name: str, e: ImportError) -> None:
|
||||
Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
|
||||
tb = traceback.format_exc()
|
||||
if not tb.startswith("NoneType: None"):
|
||||
Console.write_line("->", tb)
|
||||
|
||||
elif e is not None:
|
||||
Console.write_line("->", str(e))
|
||||
|
||||
exit(1)
|
||||
@@ -2,5 +2,4 @@ art==6.5
|
||||
colorama==0.4.6
|
||||
tabulate==0.9.0
|
||||
termcolor==3.1.0
|
||||
mysql-connector-python==9.4.0
|
||||
pynput==1.8.1
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Type
|
||||
|
||||
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
||||
@@ -7,13 +8,19 @@ from . import postgres as _postgres
|
||||
from .table_manager import TableManager
|
||||
|
||||
|
||||
def _with_migrations(self: _ApplicationABC, *paths: list[str]) -> _ApplicationABC:
|
||||
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
||||
from cpl.application.host import Host
|
||||
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
|
||||
migration_service = self._services.get_service(MigrationService)
|
||||
migration_service.with_directory("./scripts")
|
||||
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
|
||||
|
||||
if isinstance(paths, str):
|
||||
paths = [paths]
|
||||
|
||||
for path in paths:
|
||||
migration_service.with_directory(path)
|
||||
|
||||
Host.run(migration_service.migrate)
|
||||
|
||||
return self
|
||||
|
||||
@@ -156,13 +156,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
:param dict result: Result from the database
|
||||
:return:
|
||||
"""
|
||||
value_map: dict[str, T] = {}
|
||||
value_map: dict[str, Any] = {}
|
||||
db_names = self.__db_names.items()
|
||||
|
||||
for db_name, value in result.items():
|
||||
# Find the attribute name corresponding to the db_name
|
||||
attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
|
||||
if attr_name:
|
||||
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
|
||||
attr_name = next((k for k, v in db_names if v == db_name), None)
|
||||
if not attr_name:
|
||||
continue
|
||||
|
||||
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
|
||||
|
||||
return self._model_type(**value_map)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
import sqlparse
|
||||
import aiomysql
|
||||
from mysql.connector.aio import MySQLConnectionPool
|
||||
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.db_logger import DBLogger
|
||||
@@ -9,97 +9,83 @@ from cpl.database.model import DatabaseSettings
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class MySQLPool:
|
||||
"""
|
||||
Create a pool when connecting to MySQL, which will decrease the time spent in
|
||||
requesting connection, creating connection, and closing connection.
|
||||
"""
|
||||
|
||||
def __init__(self, database_settings: DatabaseSettings):
|
||||
self._db_settings = database_settings
|
||||
self.pool: Optional[aiomysql.Pool] = None
|
||||
self._dbconfig = {
|
||||
"host": database_settings.host,
|
||||
"port": database_settings.port,
|
||||
"user": database_settings.user,
|
||||
"password": database_settings.password,
|
||||
"database": database_settings.database,
|
||||
"ssl_disabled": True,
|
||||
}
|
||||
self._pool: Optional[MySQLConnectionPool] = None
|
||||
|
||||
async def _get_pool(self):
|
||||
if self.pool is None or self.pool._closed:
|
||||
if self._pool is None:
|
||||
self._pool = MySQLConnectionPool(
|
||||
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
||||
)
|
||||
await self._pool.initialize_pool()
|
||||
|
||||
con = await self._pool.get_connection()
|
||||
try:
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=self._db_settings.host,
|
||||
port=self._db_settings.port,
|
||||
user=self._db_settings.user,
|
||||
password=self._db_settings.password,
|
||||
db=self._db_settings.database,
|
||||
minsize=1,
|
||||
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
|
||||
)
|
||||
async with await con.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
except Exception as e:
|
||||
_logger.fatal("Failed to connect to the database", e)
|
||||
raise
|
||||
return self.pool
|
||||
_logger.fatal(f"Error connecting to the database: {e}")
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
return self._pool
|
||||
|
||||
@staticmethod
|
||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||
result = []
|
||||
if multi:
|
||||
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
||||
for q in queries:
|
||||
if q.strip() == "":
|
||||
continue
|
||||
await cursor.execute(q, args)
|
||||
if cursor.description is not None:
|
||||
result = await cursor.fetchall()
|
||||
else:
|
||||
await cursor.execute(query, args)
|
||||
if cursor.description is not None:
|
||||
result = await cursor.fetchall()
|
||||
|
||||
return result
|
||||
|
||||
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
||||
"""
|
||||
Execute a SQL statement, it could be with args and without args. The usage is
|
||||
similar to the execute() function in aiomysql.
|
||||
:param query: SQL clause
|
||||
:param args: args needed by the SQL clause
|
||||
:param multi: if the query is a multi-statement
|
||||
:return: return result
|
||||
"""
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as con:
|
||||
async with con.cursor() as cursor:
|
||||
await self._exec_sql(cursor, query, args, multi)
|
||||
con = await pool.get_connection()
|
||||
try:
|
||||
async with await con.cursor() as cursor:
|
||||
result = await self._exec_sql(cursor, query, args, multi)
|
||||
await con.commit()
|
||||
|
||||
if cursor.description is not None: # Query returns rows
|
||||
res = await cursor.fetchall()
|
||||
if res is None:
|
||||
return []
|
||||
|
||||
return [list(row) for row in res]
|
||||
else:
|
||||
return []
|
||||
return result
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
||||
"""
|
||||
Execute a SQL statement, it could be with args and without args. The usage is
|
||||
similar to the execute() function in aiomysql.
|
||||
:param query: SQL clause
|
||||
:param args: args needed by the SQL clause
|
||||
:param multi: if the query is a multi-statement
|
||||
:return: return result
|
||||
"""
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as con:
|
||||
async with con.cursor() as cursor:
|
||||
await self._exec_sql(cursor, query, args, multi)
|
||||
res = await cursor.fetchall()
|
||||
con = await pool.get_connection()
|
||||
try:
|
||||
async with await con.cursor() as cursor:
|
||||
res = await self._exec_sql(cursor, query, args, multi)
|
||||
return list(res)
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
||||
"""
|
||||
Execute a SQL statement, it could be with args and without args. The usage is
|
||||
similar to the execute() function in aiomysql.
|
||||
:param query: SQL clause
|
||||
:param args: args needed by the SQL clause
|
||||
:param multi: if the query is a multi-statement
|
||||
:return: return result
|
||||
"""
|
||||
pool = await self._get_pool()
|
||||
async with pool.acquire() as con:
|
||||
async with con.cursor(aiomysql.DictCursor) as cursor:
|
||||
await self._exec_sql(cursor, query, args, multi)
|
||||
res = await cursor.fetchall()
|
||||
con = await pool.get_connection()
|
||||
try:
|
||||
async with await con.cursor(dictionary=True) as cursor:
|
||||
res = await self._exec_sql(cursor, query, args, multi)
|
||||
return list(res)
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
@@ -25,21 +25,23 @@ class PostgresPool:
|
||||
f"password={database_settings.password} "
|
||||
f"dbname={database_settings.database}"
|
||||
)
|
||||
|
||||
self.pool: Optional[AsyncConnectionPool] = None
|
||||
self._pool: Optional[AsyncConnectionPool] = None
|
||||
|
||||
async def _get_pool(self):
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||
)
|
||||
await pool.open()
|
||||
try:
|
||||
async with pool.connection() as con:
|
||||
await pool.check_connection(con)
|
||||
except PoolTimeout as e:
|
||||
await pool.close()
|
||||
_logger.fatal(f"Failed to connect to the database", e)
|
||||
return pool
|
||||
if self._pool is None:
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||
)
|
||||
await pool.open()
|
||||
try:
|
||||
async with pool.connection() as con:
|
||||
await pool.check_connection(con)
|
||||
except PoolTimeout as e:
|
||||
await pool.close()
|
||||
_logger.fatal(f"Failed to connect to the database", e)
|
||||
self._pool = pool
|
||||
|
||||
return self._pool
|
||||
|
||||
@staticmethod
|
||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Union, Type, Callable
|
||||
from typing import Union, Type, Callable, Self
|
||||
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
@@ -15,12 +15,17 @@ class ServiceCollection:
|
||||
_modules: dict[str, Callable] = {}
|
||||
|
||||
@classmethod
|
||||
def with_module(cls, func: Callable, name: str = None):
|
||||
def with_module(cls, func: Callable, name: str = None) -> type[Self]:
|
||||
cls._modules[func.__name__ if name is None else name] = func
|
||||
return cls
|
||||
|
||||
def __init__(self):
|
||||
self._service_descriptors: list[ServiceDescriptor] = []
|
||||
self._loaded_modules: set[str] = set()
|
||||
|
||||
@property
|
||||
def loaded_modules(self) -> set[str]:
|
||||
return self._loaded_modules
|
||||
|
||||
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
|
||||
found = False
|
||||
@@ -45,15 +50,15 @@ class ServiceCollection:
|
||||
|
||||
return self
|
||||
|
||||
def add_singleton(self, service_type: T, service: Service = None):
|
||||
def add_singleton(self, service_type: T, service: Service = None) -> Self:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service)
|
||||
return self
|
||||
|
||||
def add_scoped(self, service_type: T, service: Service = None):
|
||||
def add_scoped(self, service_type: T, service: Service = None) -> Self:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service)
|
||||
return self
|
||||
|
||||
def add_transient(self, service_type: T, service: Service = None):
|
||||
def add_transient(self, service_type: T, service: Service = None) -> Self:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
||||
return self
|
||||
|
||||
@@ -62,7 +67,7 @@ class ServiceCollection:
|
||||
ServiceProviderABC.set_global_provider(sp)
|
||||
return sp
|
||||
|
||||
def add_module(self, module: str | object):
|
||||
def add_module(self, module: str | object) -> Self:
|
||||
if not isinstance(module, str):
|
||||
module = module.__name__
|
||||
|
||||
@@ -70,7 +75,10 @@ class ServiceCollection:
|
||||
raise ValueError(f"Module {module} not found")
|
||||
|
||||
self._modules[module](self)
|
||||
if module not in self._loaded_modules:
|
||||
self._loaded_modules.add(module)
|
||||
return self
|
||||
|
||||
def add_logging(self):
|
||||
def add_logging(self) -> Self:
|
||||
self.add_transient(LoggerABC, Logger)
|
||||
return self
|
||||
|
||||
@@ -77,7 +77,7 @@ class ServiceProvider(ServiceProviderABC):
|
||||
|
||||
return implementations
|
||||
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]:
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]:
|
||||
params = []
|
||||
for param in sig.parameters.items():
|
||||
parameter = param[1]
|
||||
|
||||
@@ -24,19 +24,19 @@ class ServiceProviderABC(ABC):
|
||||
return cls._provider
|
||||
|
||||
@classmethod
|
||||
def get_global_service(cls, instance_type: T, *args, **kwargs) -> Optional[R]:
|
||||
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||
if cls._provider is None:
|
||||
return None
|
||||
return cls._provider.get_service(instance_type, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_global_services(cls, instance_type: T, *args, **kwargs) -> list[Optional[R]]:
|
||||
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||
if cls._provider is None:
|
||||
return []
|
||||
return cls._provider.get_services(instance_type, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]: ...
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _build_service(self, service_type: type, *args, **kwargs) -> object:
|
||||
@@ -114,14 +114,22 @@ class ServiceProviderABC(ABC):
|
||||
if f is None:
|
||||
return functools.partial(cls.inject)
|
||||
|
||||
if iscoroutinefunction(f):
|
||||
@functools.wraps(f)
|
||||
async def async_inner(*args, **kwargs):
|
||||
if cls._provider is None:
|
||||
raise Exception(f"{cls.__name__} not build!")
|
||||
|
||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||
return await f(*args, *injection, **kwargs)
|
||||
|
||||
return async_inner
|
||||
|
||||
@functools.wraps(f)
|
||||
async def inner(*args, **kwargs):
|
||||
def inner(*args, **kwargs):
|
||||
if cls._provider is None:
|
||||
raise Exception(f"{cls.__name__} not build!")
|
||||
|
||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||
if iscoroutinefunction(f):
|
||||
return await f(*args, *injection, **kwargs)
|
||||
return f(*args, *injection, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
8
tests/custom/api/src/appsettings.development.json
Normal file
8
tests/custom/api/src/appsettings.development.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"Logging": {
|
||||
"Path": "logs/",
|
||||
"Filename": "log_$start_time.log",
|
||||
"ConsoleLevel": "TRACE",
|
||||
"Level": "TRACE"
|
||||
}
|
||||
}
|
||||
26
tests/custom/api/src/appsettings.edrafts-pc.json
Normal file
26
tests/custom/api/src/appsettings.edrafts-pc.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"TimeFormat": {
|
||||
"DateFormat": "%Y-%m-%d",
|
||||
"TimeFormat": "%H:%M:%S",
|
||||
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
|
||||
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
|
||||
},
|
||||
|
||||
"Log": {
|
||||
"Path": "logs/",
|
||||
"Filename": "log_$start_time.log",
|
||||
"ConsoleLevel": "TRACE",
|
||||
"Level": "TRACE"
|
||||
},
|
||||
|
||||
"Database": {
|
||||
"Host": "localhost",
|
||||
"User": "cpl",
|
||||
"Port": 3306,
|
||||
"Password": "cpl",
|
||||
"Database": "cpl",
|
||||
"Charset": "utf8mb4",
|
||||
"UseUnicode": "true",
|
||||
"Buffered": "true"
|
||||
}
|
||||
}
|
||||
15
tests/custom/api/src/appsettings.json
Normal file
15
tests/custom/api/src/appsettings.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"TimeFormat": {
|
||||
"DateFormat": "%Y-%m-%d",
|
||||
"TimeFormat": "%H:%M:%S",
|
||||
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
|
||||
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
|
||||
},
|
||||
|
||||
"Log": {
|
||||
"Path": "logs/",
|
||||
"Filename": "log_$start_time.log",
|
||||
"ConsoleLevel": "ERROR",
|
||||
"Level": "WARNING"
|
||||
}
|
||||
}
|
||||
@@ -1,20 +1,31 @@
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl import api
|
||||
from cpl.api.web_app import WebApp
|
||||
from cpl.application import ApplicationBuilder
|
||||
from custom.api.src.service import PingService
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from service import PingService
|
||||
|
||||
|
||||
def main():
|
||||
builder = ApplicationBuilder[WebApp](WebApp)
|
||||
|
||||
Configuration.add_json_file(f"appsettings.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
||||
|
||||
builder.services.add_logging()
|
||||
builder.services.add_transient(PingService)
|
||||
builder.services.add_module(api)
|
||||
|
||||
app = builder.build()
|
||||
app.with_logging()
|
||||
app.with_database()
|
||||
|
||||
app.with_authentication()
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
||||
app.with_routes_directory("routes")
|
||||
app.with_logging()
|
||||
|
||||
app.run()
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@ from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.api.router import Router
|
||||
from cpl.core.log import Logger
|
||||
from custom.api.src.service import PingService
|
||||
from service import PingService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: Logger):
|
||||
logger.info(f"Ping: {ping}")
|
||||
|
||||
Reference in New Issue
Block a user