Compare commits

..

1 Commits

Author SHA1 Message Date
b2344a8065 Fixed api build
Some checks failed
Build on push / prepare (push) Successful in 8s
Build on push / core (push) Successful in 18s
Build on push / query (push) Successful in 17s
Build on push / dependency (push) Successful in 17s
Build on push / application (push) Successful in 15s
Build on push / translation (push) Successful in 15s
Build on push / database (push) Successful in 18s
Build on push / mail (push) Successful in 18s
Build on push / auth (push) Successful in 17s
Build on push / api (push) Failing after 14s
2025-09-19 21:07:00 +02:00
192 changed files with 1504 additions and 2492 deletions

View File

@@ -1,26 +0,0 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Installing black
run: python3.12 -m pip install black
- name: Checking black
run: python3.12 -m black src --check

3
.gitignore vendored
View File

@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
# cpl unittest stuff # cpl unittest stuff
unittests/test_*_playground unittests/test_*_playground
# cpl logs
**/logs/*.jsonl

View File

@@ -1,47 +0,0 @@
from starlette.responses import JSONResponse
from cpl import api
from cpl.api.application.web_app import WebApp
from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging()
builder.services.add_structured_logging()
builder.services.add_transient(PingService)
builder.services.add_module(api)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
app = builder.build()
app.with_logging()
app.with_database()
app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
app.with_routes_directory("routes")
provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
app.run()
if __name__ == "__main__":
main()

View File

@@ -1,16 +0,0 @@
from urllib.request import Request
from service import PingService
from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger):
logger.info(f"Ping: {ping}")
return JSONResponse(ping.ping(r))

View File

@@ -1,26 +0,0 @@
{
"TimeFormat": {
"DateFormat": "%Y-%m-%d",
"TimeFormat": "%H:%M:%S",
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
},
"Log": {
"Path": "logs/",
"Filename": "log_$start_time.log",
"ConsoleLevel": "TRACE",
"Level": "TRACE"
},
"Database": {
"Host": "localhost",
"User": "cpl",
"Port": 3306,
"Password": "cpl",
"Database": "cpl",
"Charset": "utf8mb4",
"UseUnicode": "true",
"Buffered": "true"
}
}

View File

@@ -1,15 +0,0 @@
{
"TimeFormat": {
"DateFormat": "%Y-%m-%d",
"TimeFormat": "%H:%M:%S",
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
},
"Log": {
"Path": "logs/",
"Filename": "log_$start_time.log",
"ConsoleLevel": "ERROR",
"Level": "WARNING"
}
}

View File

@@ -1,8 +0,0 @@
{
"Logging": {
"Path": "logs/",
"Filename": "log_$start_time.log",
"ConsoleLevel": "TRACE",
"Level": "TRACE"
}
}

View File

@@ -1,60 +0,0 @@
from cpl.core.console import Console
from cpl.core.utils.benchmark import Benchmark
from cpl.query.enumerable import Enumerable
from cpl.query.immutable_list import ImmutableList
from cpl.query.list import List
from cpl.query.set import Set
def _default():
Console.write_line(Enumerable.empty().to_list())
Console.write_line(Enumerable.range(0, 100).length)
Console.write_line(Enumerable.range(0, 100).to_list())
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
Console.write_line(
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
)
Console.write_line(List)
s =Enumerable.range(0, 10).to_set()
Console.write_line(s)
s.add(1)
Console.write_line(s)
data = Enumerable(
[
{"name": "Alice", "age": 30},
{"name": "Dave", "age": 35},
{"name": "Charlie", "age": 25},
{"name": "Bob", "age": 25},
]
)
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
def t_benchmark(data: list):
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all(
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
)
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
def main():
N = 10_000_000
data = list(range(N))
#t_benchmark(data)
Console.write_line()
_default()
if __name__ == "__main__":
main()

View File

@@ -1,61 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# Find and combine requirements from src/cpl-*/requirements.txt,
# filtering out lines whose *package name* starts with "cpl-".
# Works with pinned versions, extras, markers, editable installs, and VCS refs.
shopt -s nullglob
req_files=(src/cpl-*/requirements.txt)
if ((${#req_files[@]} == 0)); then
echo "No requirements files found at src/cpl-*/requirements.txt" >&2
exit 1
fi
tmp_combined="$(mktemp)"
trap 'rm -f "$tmp_combined"' EXIT
# Concatenate, trim comments/whitespace, filter out cpl-* packages, dedupe.
# We keep non-package options/flags/constraints as-is.
awk '
function trim(s){ sub(/^[[:space:]]+/,"",s); sub(/[[:space:]]+$/,"",s); return s }
{
line=$0
# drop full-line comments and strip inline comments
if (line ~ /^[[:space:]]*#/) next
sub(/#[^!].*$/,"",line) # strip trailing comment (simple heuristic)
line=trim(line)
if (line == "") next
# Determine the package *name* even for "-e", extras, pins, markers, or VCS "@"
e = line
sub(/^-e[[:space:]]+/,"",e) # remove editable prefix
# Tokenize up to the first of these separators: space, [ < > = ! ~ ; @
token = e
sub(/\[.*/,"",token) # remove extras quickly
n = split(token, a, /[<>=!~;@[:space:]]/)
name = tolower(a[1])
# If the first token (name) starts with "cpl-", skip this requirement
if (name ~ /^cpl-/) next
print line
}
' "${req_files[@]}" | sort -u > "$tmp_combined"
if ! [ -s "$tmp_combined" ]; then
echo "Nothing to install after filtering out cpl-* packages." >&2
exit 0
fi
echo "Installing dependencies (excluding cpl-*) from:"
printf ' - %s\n' "${req_files[@]}"
echo
echo "Final set to install:"
cat "$tmp_combined"
echo
# Use python -m pip for reliability; change to python3 if needed.
python -m pip install -r "$tmp_combined"

View File

@@ -1,36 +0,0 @@
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
from .logger import APILogger
from .settings import ApiSettings
def add_api(collection: _ServiceCollection):
try:
from cpl.database import mysql
collection.add_module(mysql)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-database", e)
try:
from cpl import auth
from cpl.auth import permission
collection.add_module(auth)
collection.add_module(permission)
except ImportError as e:
from cpl.core.errors import dependency_error
dependency_error("cpl-auth", e)
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
collection.add_singleton(PolicyRegistry)
collection.add_singleton(RouteRegistry)
_ServiceCollection.with_module(add_api, __name__)

View File

@@ -1 +0,0 @@
from .asgi_middleware_abc import ASGIMiddleware

View File

@@ -1,15 +0,0 @@
from abc import ABC, abstractmethod
from starlette.types import Scope, Receive, Send
class ASGIMiddleware(ABC):
@abstractmethod
def __init__(self, app):
self._app = app
def _call_next(self, scope: Scope, receive: Receive, send: Send):
return self._app(scope, receive, send)
@abstractmethod
async def __call__(self, scope: Scope, receive: Receive, send: Send): ...

View File

@@ -0,0 +1,7 @@
from cpl.core.log.logger import Logger
class APILogger(Logger):
def __init__(self, source: str):
Logger.__init__(self, source, "api")

View File

@@ -1 +0,0 @@
from .web_app import WebApp

View File

@@ -1,249 +0,0 @@
import os
from enum import Enum
from typing import Mapping, Any, Callable, Self, Union
import uvicorn
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.types import ExceptionHandler
from cpl import api, auth
from cpl.api.error import APIError
from cpl.api.logger import APILogger
from cpl.api.middleware.authentication import AuthenticationMiddleware
from cpl.api.middleware.authorization import AuthorizationMiddleware
from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware
from cpl.api.model.api_route import ApiRoute
from cpl.api.model.policy import Policy
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry
from cpl.api.router import Router
from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration
from cpl.dependency.service_provider_abc import ServiceProviderABC
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
super().__init__(services, [auth, api])
self._app: Starlette | None = None
self._logger = services.get_service(APILogger)
self._api_settings = Configuration.get(ApiSettings)
self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception,
APIError: self._handle_exception,
}
async def _handle_exception(self, request: Request, exc: Exception):
if isinstance(exc, APIError):
self._logger.error(exc)
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
if hasattr(request.state, "request_id"):
self._logger.error(f"Request {request.state.request_id}", exc)
else:
self._logger.error("Request unknown", exc)
return JSONResponse({"error": str(exc)}, status_code=500)
def _get_allowed_origins(self):
origins = self._api_settings.allowed_origins
if origins is None or origins == "":
self._logger.warning("No allowed origins specified, allowing all origins")
return ["*"]
self._logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_database(self) -> Self:
self.with_migrations()
self.with_seeders()
return self
def with_app(self, app: Starlette) -> Self:
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
self._app = app
return self
def _check_for_app(self):
if self._app is not None:
raise ValueError("App is already set, cannot add routes or middleware")
def with_routes_directory(self, directory: str) -> Self:
self._check_for_app()
assert directory is not None, "directory must not be None"
base = directory.replace("/", ".").replace("\\", ".")
for filename in os.listdir(directory):
if not filename.endswith(".py") or filename == "__init__.py":
continue
__import__(f"{base}.{filename[:-3]}")
return self
def with_routes(
self,
routes: list[ApiRoute],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert self._routes is not None, "routes must not be None"
assert all(isinstance(route, ApiRoute) for route in routes), "all routes must be of type ApiRoute"
for route in routes:
self.with_route(
route.path,
route.fn,
method,
authentication,
roles,
permissions,
policies,
match,
)
return self
def with_route(
self,
path: str,
fn: Callable[[Request], Any],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
assert method in [
"GET",
"HEAD",
"POST",
"PUT",
"PATCH",
"DELETE",
"OPTIONS",
], "method must be a valid HTTP method"
Router.route(path, method, registry=self._routes)(fn)
if authentication:
Router.authenticate()(fn)
if roles or permissions or policies:
Router.authorize(roles, permissions, policies, match)(fn)
return self
def with_middleware(self, middleware: PartialMiddleware) -> Self:
self._check_for_app()
if isinstance(middleware, Middleware):
self._middleware.append(middleware)
elif callable(middleware):
self._middleware.append(Middleware(middleware))
else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
return self
def with_authentication(self) -> Self:
self.with_middleware(AuthenticationMiddleware)
return self
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
if policies:
_policies = []
if not isinstance(policies, list):
policies = list(policies)
for i, policy in enumerate(policies):
if isinstance(policy, dict):
for name, resolver in policy.items():
if not isinstance(name, str):
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
continue
if not callable(resolver):
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
continue
_policies.append(Policy(name, resolver))
continue
_policies.append(policy)
self._policies.extend(_policies)
self.with_middleware(AuthorizationMiddleware)
return self
def _validate_policies(self):
for rule in Router.get_authorization_rules():
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self):
self._logger.debug(f"Preparing API")
self._validate_policies()
if self._app is None:
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
app = Starlette(
routes=routes,
middleware=[
*self._middleware,
Middleware(
CORSMiddleware,
allow_origins=self._get_allowed_origins(),
allow_methods=["*"],
allow_headers=["*"],
),
],
exception_handlers=self._exception_handlers,
)
else:
app = self._app
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
)
server = uvicorn.Server(config)
await server.serve()
self._logger.info("Shutdown API")

View File

@@ -1,30 +1,9 @@
from http.client import HTTPException from http.client import HTTPException
from starlette.responses import JSONResponse
from starlette.types import Scope, Receive, Send
class APIError(HTTPException): class APIError(HTTPException):
status_code = 500 status_code = 500
def __init__(self, message: str = ""):
super().__init__(self.status_code, message)
self._message = message
@property
def error_message(self) -> str:
if self._message:
return f"{type(self).__name__}: {self._message}"
return f"{type(self).__name__}"
async def asgi_response(self, scope: Scope, receive: Receive, send: Send):
r = JSONResponse({"error": self.error_message}, status_code=self.status_code)
return await r(scope, receive, send)
def response(self):
return JSONResponse({"error": self.error_message}, status_code=self.status_code)
class Unauthorized(APIError): class Unauthorized(APIError):
status_code = 401 status_code = 401

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class APILogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "api")

View File

@@ -1,4 +0,0 @@
from .authentication import AuthenticationMiddleware
from .authorization import AuthorizationMiddleware
from .logging import LoggingMiddleware
from .request import RequestMiddleware

View File

@@ -1,80 +0,0 @@
from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
class AuthenticationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._logger = logger
self._keycloak = keycloak
self._user_dao = user_dao
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)
if not request.headers.get("Authorization"):
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None)
if not auth_header or not auth_header.startswith("Bearer "):
return await Unauthorized("Invalid Authorization header").asgi_response(scope, receive, send)
token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token):
self._logger.debug(f"Unauthorized access to {url}, invalid token")
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create
keycloak_id = self._keycloak.get_user_id(token)
if keycloak_id is None:
return await Unauthorized("Failed to get user id from token").asgi_response(scope, receive, send)
user = await self._get_or_crate_user(keycloak_id)
if user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user
set_user(user)
return await self._call_next(scope, receive, send)
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
if existing is not None:
return existing
user = AuthUser(0, keycloak_id)
uid = await self._user_dao.create(user)
return await self._user_dao.get_by_id(uid)
async def _verify_login(self, token: str) -> bool:
try:
token_info = self._keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
self._logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
self._logger.error(f"Unexpected error during token verification: {e}")
return False

View File

@@ -1,73 +0,0 @@
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized, Forbidden
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
class AuthorizationMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._logger = logger
self._policies = policies
self._user_dao = user_dao
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
if url not in Router.get_authorization_rules_paths():
self._logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send)
user = get_user()
if not user:
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
roles = await user.roles
request.state.roles = roles
role_names = [r.name for r in roles]
perms = await user.permissions
request.state.permissions = perms
perm_names = [p.name for p in perms]
for rule in Router.get_authorization_rules():
match = rule["match"]
if rule["roles"]:
if match == ValidationMatch.all and not all(r in role_names for r in rule["roles"]):
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(r in role_names for r in rule["roles"]):
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.warning(f"Authorization policy '{policy_name}' not found")
continue
if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)

View File

@@ -1,46 +1,21 @@
import time import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.types import Receive, Scope, Send from starlette.responses import Response
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.api_logger import APILogger
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request _logger = APILogger(__name__)
from cpl.dependency import ServiceProviderABC
class LoggingMiddleware(ASGIMiddleware): class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
@ServiceProviderABC.inject
def __init__(self, app, logger: APILogger):
ASGIMiddleware.__init__(self, app)
self._logger = logger
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self._call_next(scope, receive, send)
return
request = get_request()
await self._log_request(request) await self._log_request(request)
start_time = time.time() response = await call_next(request)
await self._log_after_request(request, response)
response_body = b"" return response
status_code = 500
async def send_wrapper(message):
nonlocal response_body, status_code
if message["type"] == "http.response.start":
status_code = message["status"]
if message["type"] == "http.response.body":
response_body += message.get("body", b"")
await send(message)
await self._call_next(scope, receive, send_wrapper)
duration = (time.time() - start_time) * 1000
await self._log_after_request(request, status_code, duration)
@staticmethod @staticmethod
def _filter_relevant_headers(headers: dict) -> dict: def _filter_relevant_headers(headers: dict) -> dict:
@@ -55,9 +30,10 @@ class LoggingMiddleware(ASGIMiddleware):
} }
return {key: value for key, value in headers.items() if key in relevant_keys} return {key: value for key, value in headers.items() if key in relevant_keys}
async def _log_request(self, request: Request): @classmethod
self._logger.debug( async def _log_request(cls, request: Request):
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}" _logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
) )
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
@@ -65,7 +41,7 @@ class LoggingMiddleware(ASGIMiddleware):
user = get_user() user = get_user()
request_info = { request_info = {
"headers": self._filter_relevant_headers(dict(request.headers)), "headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params), "args": dict(request.query_params),
"form-data": ( "form-data": (
await request.form() await request.form()
@@ -79,9 +55,11 @@ class LoggingMiddleware(ASGIMiddleware):
), ),
} }
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}") _logger.trace(f"Request {request.state.request_id}: {request_info}")
async def _log_after_request(self, request: Request, status_code: int, duration: float): @staticmethod
self._logger.info( async def _log_after_request(request: Request, response: Response):
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms" duration = (time.time() - request.state.start_time) * 1000
_logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
) )

View File

@@ -3,54 +3,46 @@ from contextvars import ContextVar
from typing import Optional, Union from typing import Optional, Union
from uuid import uuid4 from uuid import uuid4
from starlette.requests import Request from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import Scope, Receive, Send from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.api_logger import APILogger
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
from cpl.dependency import ServiceProviderABC
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None) _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
_logger = APILogger(__name__)
class RequestMiddleware(ASGIMiddleware):
@ServiceProviderABC.inject class RequestMiddleware(BaseHTTPMiddleware):
def __init__(self, app, logger: APILogger): _request_token = {}
ASGIMiddleware.__init__(self, app) _user_token = {}
self._logger = logger @classmethod
async def set_request_data(cls, request: TRequest):
self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = Request(scope, receive, send)
await self.set_request_data(request)
try:
await self._app(scope, receive, send)
finally:
await self.clean_request_data()
async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4() request.state.request_id = uuid4()
request.state.start_time = time.time() request.state.start_time = time.time()
self._logger.trace(f"Set new current request: {request.state.request_id}") _logger.trace(f"Set new current request: {request.state.request_id}")
self._ctx_token = _request_context.set(request) cls._request_token[request.state.request_id] = _request_context.set(request)
async def clean_request_data(self): @classmethod
async def clean_request_data(cls):
request = get_request() request = get_request()
if request is None: if request is None:
return return
if self._ctx_token is None: if request.state.request_id in cls._request_token:
return _request_context.reset(cls._request_token[request.state.request_id])
self._logger.trace(f"Clearing current request: {request.state.request_id}") async def dispatch(self, request: TRequest, call_next):
_request_context.reset(self._ctx_token) await self.set_request_data(request)
try:
response = await call_next(request)
return response
finally:
await self.clean_request_data()
def get_request() -> Optional[TRequest]: def get_request() -> Optional[Union[TRequest, WebSocket]]:
return _request_context.get() return _request_context.get()

View File

@@ -1,3 +0,0 @@
from .api_route import ApiRoute
from .policy import Policy
from .validation_match import ValidationMatch

View File

@@ -1,43 +0,0 @@
from typing import Callable
from starlette.routing import Route
from cpl.api.typing import HTTPMethods
class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
self._path = path
self._fn = fn
self._method = method
self._kwargs = kwargs
@property
def name(self) -> str:
return self._fn.__name__
@property
def fn(self) -> Callable:
return self._fn
@property
def path(self) -> str:
return self._path
@property
def method(self) -> HTTPMethods:
return self._method
@property
def kwargs(self) -> dict:
return self._kwargs
def to_starlette(self, wrap_endpoint: Callable = None) -> Route:
return Route(
self._path,
self._fn if not wrap_endpoint else wrap_endpoint(self._fn),
methods=[self._method],
**self._kwargs,
)

View File

@@ -1,34 +0,0 @@
from asyncio import iscoroutinefunction
from typing import Optional
from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user
class Policy:
def __init__(
self,
name: str,
resolver: PolicyResolver = None,
):
self._name = name
self._resolver: Optional[PolicyResolver] = resolver
@property
def name(self) -> str:
return self._name
@property
def resolvers(self) -> PolicyResolver:
return self._resolver
async def resolve(self, *args, **kwargs) -> bool:
if not self._resolver:
return True
if callable(self._resolver):
if iscoroutinefunction(self._resolver):
return await self._resolver(get_user())
return self._resolver(get_user())
return False

View File

@@ -1,6 +0,0 @@
from enum import Enum
class ValidationMatch(Enum):
any = "any"
all = "all"

View File

@@ -1,2 +0,0 @@
from .policy import PolicyRegistry
from .route import RouteRegistry

View File

@@ -1,28 +0,0 @@
from typing import Optional
from cpl.api.model.policy import Policy
from cpl.core.abc.registry_abc import RegistryABC
class PolicyRegistry(RegistryABC):
def __init__(self):
RegistryABC.__init__(self)
def extend(self, items: list[Policy]):
for policy in items:
self.add(policy)
def add(self, item: Policy):
assert isinstance(item, Policy), "policy must be an instance of Policy"
if item.name in self._items:
raise ValueError(f"Policy {item.name} is already registered")
self._items[item.name] = item
def get(self, key: str) -> Optional[Policy]:
return self._items.get(key)
def all(self) -> list[Policy]:
return list(self._items.values())

View File

@@ -1,32 +0,0 @@
from typing import Optional
from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC
class RouteRegistry(RegistryABC):
def __init__(self):
RegistryABC.__init__(self)
def extend(self, items: list[ApiRoute]):
for policy in items:
self.add(policy)
def add(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
if item.path in self._items:
raise ValueError(f"ApiRoute {item.path} is already registered")
self._items[item.path] = item
def set(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
self._items[item.path] = item
def get(self, key: str) -> Optional[ApiRoute]:
return self._items.get(key)
def all(self) -> list[ApiRoute]:
return list(self._items.values())

View File

@@ -1,136 +1,41 @@
from enum import Enum from starlette.routing import Route
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods
class Router: class Router:
_auth_required: list[str] = [] _registered_routes: list[Route] = []
_authorization_rules: dict[str, dict] = {}
@classmethod @classmethod
def get_auth_required_routes(cls) -> list[str]: def get_routes(cls) -> list[Route]:
return cls._auth_required return cls._registered_routes
@classmethod @classmethod
def get_authorization_rules_paths(cls) -> list[str]: def route(cls, path=None, **kwargs):
return list(cls._authorization_rules.keys())
@classmethod
def get_authorization_rules(cls) -> list[dict]:
return list(cls._authorization_rules.values())
@classmethod
def authenticate(cls):
"""
Decorator to mark a route as requiring authentication.
Usage:
@Route.authenticate()
@Route.get("/example")
async def example_endpoint(request: TRequest):
...
"""
def inner(fn): def inner(fn):
route_path = getattr(fn, "_route_path", None) cls._registered_routes.append(Route(path, fn, **kwargs))
if route_path and route_path not in cls._auth_required:
cls._auth_required.append(route_path)
return fn
return inner
@classmethod
def authorize(
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
"""
Decorator to mark a route as requiring authorization.
Usage:
@Route.authorize()
@Route.get("/example")
async def example_endpoint(request: TRequest):
...
"""
assert roles is None or isinstance(roles, list), "roles must be a list of strings"
assert permissions is None or isinstance(permissions, list), "permissions must be a list of strings"
assert policies is None or isinstance(policies, list), "policies must be a list of strings"
assert match is None or isinstance(match, ValidationMatch), "match must be an instance of ValidationMatch"
if roles is not None:
for role in roles:
if isinstance(role, Enum):
roles[roles.index(role)] = role.value
if permissions is not None:
for perm in permissions:
if isinstance(perm, Enum):
permissions[permissions.index(perm)] = perm.value
def inner(fn):
path = getattr(fn, "_route_path", None)
if not path:
return fn
if path in cls._authorization_rules:
raise ValueError(f"Route {path} is already registered for authorization")
cls._authorization_rules[path] = {
"roles": roles or [],
"permissions": permissions or [],
"policies": policies or [],
"match": match or ValidationMatch.all,
}
return fn
return inner
@classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.api_route import ApiRoute
if not registry:
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else:
routes = registry
def inner(fn):
routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path) setattr(fn, "_route_path", path)
return fn return fn
return inner return inner
@classmethod @classmethod
def get(cls, path: str, **kwargs): def get(cls, path=None, **kwargs):
return cls.route(path, "GET", **kwargs) return cls.route(path, methods=["GET"], **kwargs)
@classmethod @classmethod
def head(cls, path: str, **kwargs): def post(cls, path=None, **kwargs):
return cls.route(path, "HEAD", **kwargs) return cls.route(path, methods=["POST"], **kwargs)
@classmethod @classmethod
def post(cls, path: str, **kwargs): def head(cls, path=None, **kwargs):
return cls.route(path, "POST", **kwargs) return cls.route(path, methods=["HEAD"], **kwargs)
@classmethod @classmethod
def put(cls, path: str, **kwargs): def put(cls, path=None, **kwargs):
return cls.route(path, "PUT", **kwargs) return cls.route(path, methods=["PUT"], **kwargs)
@classmethod @classmethod
def patch(cls, path: str, **kwargs): def delete(cls, path=None, **kwargs):
return cls.route(path, "PATCH", **kwargs) return cls.route(path, methods=["DELETE"], **kwargs)
@classmethod
def delete(cls, path: str, **kwargs):
return cls.route(path, "DELETE", **kwargs)
@classmethod @classmethod
def override(cls): def override(cls):
@@ -143,22 +48,13 @@ class Router:
... ...
""" """
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn): def inner(fn):
path = getattr(fn, "_route_path", None) route_path = getattr(fn, "_route_path", None)
if path is None:
raise ValueError("Cannot override a route that has not been registered yet")
route = routes.get(path) routes = list(filter(lambda x: x.path == route_path, cls._registered_routes))
if route is None: for route in routes[:-1]:
raise ValueError(f"Cannot override a route that does not exist: {path}") cls._registered_routes.remove(route)
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
setattr(fn, "_route_path", path)
return fn return fn
return inner return inner

View File

@@ -1,19 +1,13 @@
from typing import Union, Literal, Callable, Type, Awaitable from typing import Union, Literal, Callable
from urllib.request import Request from urllib.request import Request
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.types import ASGIApp from starlette.types import ASGIApp
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.auth.schema import AuthUser
TRequest = Union[Request, WebSocket] TRequest = Union[Request, WebSocket]
HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
PartialMiddleware = Union[ PartialMiddleware = Union[
ASGIMiddleware,
Type[ASGIMiddleware],
Middleware, Middleware,
Callable[[ASGIApp], ASGIApp], Callable[[ASGIApp], ASGIApp],
] ]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

@@ -0,0 +1,153 @@
import os
from typing import Mapping, Any, Callable
import uvicorn
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.types import ExceptionHandler
from cpl.api.api_logger import APILogger
from cpl.api.api_settings import ApiSettings
from cpl.api.error import APIError
from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware
from cpl.api.router import Router
from cpl.api.typing import HTTPMethods, PartialMiddleware
from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger("API")
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC):
super().__init__(services)
self._app: Starlette | None = None
self._api_settings = Configuration.get(ApiSettings)
self._routes: list[Route] = []
self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = {Exception: self.handle_exception}
@staticmethod
async def handle_exception(request: Request, exc: Exception):
if hasattr(request.state, "request_id"):
_logger.error(f"Request {request.state.request_id}", exc)
else:
_logger.error("Request unknown", exc)
if isinstance(exc, APIError):
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
return JSONResponse({"error": str(exc)}, status_code=500)
def _get_allowed_origins(self):
origins = self._api_settings.allowed_origins
if origins is None or origins == "":
_logger.warning("No allowed origins specified, allowing all origins")
return ["*"]
_logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_app(self, app: Starlette):
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
self._app = app
return self
def _check_for_app(self):
if self._app is not None:
raise ValueError("App is already set, cannot add routes or middleware")
def with_routes_directory(self, directory: str) -> "WebApp":
self._check_for_app()
assert directory is not None, "directory must not be None"
base = directory.replace("/", ".").replace("\\", ".")
for filename in os.listdir(directory):
if not filename.endswith(".py") or filename == "__init__.py":
continue
__import__(f"{base}.{filename[:-3]}")
return self
def with_routes(self, routes: list[Route]) -> "WebApp":
self._check_for_app()
assert self._routes is not None, "routes must not be None"
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
self._routes.extend(routes)
return self
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> "WebApp":
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
assert method in ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], "method must be a valid HTTP method"
self._routes.append(Route(path, fn, methods=[method], **kwargs))
return self
def with_middleware(self, middleware: PartialMiddleware) -> "WebApp":
self._check_for_app()
if isinstance(middleware, Middleware):
self._middleware.append(middleware)
elif callable(middleware):
self._middleware.append(Middleware(middleware))
else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
return self
def main(self):
_logger.debug(f"Preparing API")
if self._app is None:
routes = [
Route(
path=route.path,
endpoint=self._services.inject(route.endpoint),
methods=route.methods,
name=route.name,
)
for route in self._routes + Router.get_routes()
]
app = Starlette(
routes=routes,
middleware=[
*self._middleware,
Middleware(
CORSMiddleware,
allow_origins=self._get_allowed_origins(),
allow_methods=["*"],
allow_headers=["*"],
),
],
exception_handlers=self._exception_handlers,
)
else:
app = self._app
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
uvicorn.run(
app,
host=self._api_settings.host,
port=self._api_settings.port,
log_config=None,
)
_logger.info("Shutdown API")

View File

@@ -3,16 +3,16 @@ requires = ["setuptools>=70.1.0", "wheel>=0.43.0"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
name = "cpl-api" name = "cpl-application"
version = "2024.7.0" version = "2024.7.0"
description = "CPL api" description = "CPL application"
readme ="CPL api package" readme ="CPL application package"
requires-python = ">=3.12" requires-python = ">=3.12"
license = { text = "MIT" } license = { text = "MIT" }
authors = [ authors = [
{ name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" } { name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" }
] ]
keywords = ["cpl", "api", "backend", "shared", "library"] keywords = ["cpl", "application", "backend", "shared", "library"]
dynamic = ["dependencies", "optional-dependencies"] dynamic = ["dependencies", "optional-dependencies"]

View File

@@ -3,5 +3,4 @@ cpl-application
cpl-core cpl-core
cpl-dependency cpl-dependency
starlette==0.48.0 starlette==0.48.0
python-multipart==0.0.20 python-multipart==0.0.20
uvicorn==0.35.0

View File

@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
from typing import Callable, Self from typing import Callable, Self
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.console.console import Console
from cpl.core.log import LogSettings
from cpl.core.log.log_level import LogLevel from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider_abc import ServiceProviderABC
@@ -21,15 +22,8 @@ class ApplicationABC(ABC):
""" """
@abstractmethod @abstractmethod
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None): def __init__(self, services: ServiceProviderABC):
self._services = services self._services = services
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
)
@property
def required_modules(self) -> list[str]:
return self._required_modules
@classmethod @classmethod
def extend(cls, name: str | Callable, func: Callable[[Self], Self]): def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
@@ -86,7 +80,7 @@ class ApplicationABC(ABC):
try: try:
Host.run(self.main) Host.run(self.main)
except KeyboardInterrupt: except KeyboardInterrupt:
pass Console.close()
@abstractmethod @abstractmethod
def main(self): ... def main(self): ...

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency import ServiceProviderABC
class ApplicationExtensionABC(ABC): class ApplicationExtensionABC(ABC):

View File

@@ -6,7 +6,6 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
from cpl.application.abc.startup_abc import StartupABC from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host from cpl.application.host import Host
from cpl.core.errors import dependency_error
from cpl.dependency.service_collection import ServiceCollection from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC) TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -36,18 +35,6 @@ class ApplicationBuilder(Generic[TApp]):
def service_provider(self): def service_provider(self):
return self._services.build() return self._services.build()
def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules:
if module in self._services.loaded_modules:
continue
dependency_error(
module,
ImportError(
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
),
)
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder": def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
self._startup = startup self._startup = startup
return self return self
@@ -75,6 +62,4 @@ class ApplicationBuilder(Generic[TApp]):
for extension in self._app_extensions: for extension in self._app_extensions:
Host.run(extension.run, self.service_provider) Host.run(extension.run, self.service_provider)
app = self._app(self.service_provider) return self._app(self.service_provider)
self.validate_app_required_modules(app)
return app

View File

@@ -6,7 +6,7 @@ from cpl.auth import permission as _permission
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .logger import AuthLogger from .auth_logger import AuthLogger
from .keycloak_settings import KeycloakSettings from .keycloak_settings import KeycloakSettings
from .permission_seeder import PermissionSeeder from .permission_seeder import PermissionSeeder
@@ -40,10 +40,11 @@ def _add_daos(collection: _ServiceCollection):
def add_auth(collection: _ServiceCollection): def add_auth(collection: _ServiceCollection):
import os import os
try: from cpl.core.console import Console
from cpl.database.service.migration_service import MigrationService from cpl.database.service.migration_service import MigrationService
from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.model.server_type import ServerType, ServerTypes
try:
collection.add_singleton(_KeycloakClient) collection.add_singleton(_KeycloakClient)
collection.add_singleton(_KeycloakAdmin) collection.add_singleton(_KeycloakAdmin)
@@ -58,25 +59,22 @@ def add_auth(collection: _ServiceCollection):
elif ServerType.server_type == ServerTypes.MYSQL: elif ServerType.server_type == ServerTypes.MYSQL:
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql")) migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
except ImportError as e: except ImportError as e:
from cpl.core.console import Console Console.error("cpl-auth is not installed", str(e))
Console.error("cpl-database is not installed", str(e))
def add_permission(collection: _ServiceCollection): def add_permission(collection: _ServiceCollection):
from .permission_seeder import PermissionSeeder from cpl.auth.permission_seeder import PermissionSeeder
from .permission.permissions_registry import PermissionsRegistry from cpl.database.abc.data_seeder_abc import DataSeederABC
from .permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.auth.permission.permissions import Permissions
try: try:
from cpl.database.abc.data_seeder_abc import DataSeederABC
collection.add_singleton(DataSeederABC, PermissionSeeder) collection.add_singleton(DataSeederABC, PermissionSeeder)
PermissionsRegistry.with_enum(Permissions) PermissionsRegistry.with_enum(Permissions)
except ImportError as e: except ImportError as e:
from cpl.core.console import Console from cpl.core.console import Console
Console.error("cpl-database is not installed", str(e)) Console.error("cpl-auth is not installed", str(e))
_ServiceCollection.with_module(add_auth, __name__) _ServiceCollection.with_module(add_auth, __name__)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class AuthLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "auth")

View File

@@ -1,13 +1,15 @@
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings from cpl.auth.keycloak_settings import KeycloakSettings
from cpl.auth.logger import AuthLogger
_logger = AuthLogger("keycloak")
class KeycloakAdmin(_KeycloakAdmin): class KeycloakAdmin(_KeycloakAdmin):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings): def __init__(self, settings: KeycloakSettings):
# logger.info("Initializing Keycloak admin") _logger.info("Initializing Keycloak admin")
_connection = KeycloakOpenIDConnection( _connection = KeycloakOpenIDConnection(
server_url=settings.url, server_url=settings.url,
client_id=settings.client_id, client_id=settings.client_id,

View File

@@ -1,14 +1,14 @@
from typing import Optional from keycloak import KeycloakOpenID, KeycloakAdmin, KeycloakOpenIDConnection
from keycloak import KeycloakOpenID from cpl.auth.auth_logger import AuthLogger
from cpl.auth.logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings from cpl.auth.keycloak_settings import KeycloakSettings
_logger = AuthLogger("keycloak")
class KeycloakClient(KeycloakOpenID): class KeycloakClient(KeycloakOpenID):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings): def __init__(self, settings: KeycloakSettings):
KeycloakOpenID.__init__( KeycloakOpenID.__init__(
self, self,
server_url=settings.url, server_url=settings.url,
@@ -16,8 +16,11 @@ class KeycloakClient(KeycloakOpenID):
realm_name=settings.realm, realm_name=settings.realm,
client_secret_key=settings.client_secret, client_secret_key=settings.client_secret,
) )
logger.info("Initializing Keycloak client") _logger.info("Initializing Keycloak client")
connection = KeycloakOpenIDConnection(
def get_user_id(self, token: str) -> Optional[str]: server_url=settings.url,
info = self.introspect(token) client_id=settings.client_id,
return info.get("sub", None) realm_name=settings.realm,
client_secret_key=settings.client_secret,
)
self._admin = KeycloakAdmin(connection=connection)

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class AuthLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "auth")

View File

@@ -14,13 +14,14 @@ from cpl.auth.schema import (
) )
from cpl.core.utils.get_value import get_value from cpl.core.utils.get_value import get_value
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionSeeder(DataSeederABC): class PermissionSeeder(DataSeederABC):
def __init__( def __init__(
self, self,
logger: DBLogger,
permission_dao: PermissionDao, permission_dao: PermissionDao,
role_dao: RoleDao, role_dao: RoleDao,
role_permission_dao: RolePermissionDao, role_permission_dao: RolePermissionDao,
@@ -28,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
api_key_permission_dao: ApiKeyPermissionDao, api_key_permission_dao: ApiKeyPermissionDao,
): ):
DataSeederABC.__init__(self) DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao self._permission_dao = permission_dao
self._role_dao = role_dao self._role_dao = role_dao
self._role_permission_dao = role_permission_dao self._role_permission_dao = role_permission_dao
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
possible_permissions = [permission for permission in PermissionsRegistry.get()] possible_permissions = [permission for permission in PermissionsRegistry.get()]
if len(permissions) == len(possible_permissions): if len(permissions) == len(possible_permissions):
self._logger.info("Permissions already existing") _logger.info("Permissions already existing")
await self._update_missing_descriptions() await self._update_missing_descriptions()
return return
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
await self._permission_dao.delete_many(to_delete, hard_delete=True) await self._permission_dao.delete_many(to_delete, hard_delete=True)
self._logger.warning("Permissions incomplete") _logger.warning("Permissions incomplete")
permission_names = [permission.name for permission in permissions] permission_names = [permission.name for permission in permissions]
await self._permission_dao.create_many( await self._permission_dao.create_many(
[ [

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._administration.api_key import ApiKey from cpl.auth.schema._administration.api_key import ApiKey
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyDao(DbModelDaoABC[ApiKey]): class ApiKeyDao(DbModelDaoABC[ApiKey]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys")) DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
self.attribute(ApiKey.identifier, str) self.attribute(ApiKey.identifier, str)
self.attribute(ApiKey.key, str, "keystring") self.attribute(ApiKey.key, str, "keystring")

View File

@@ -6,12 +6,14 @@ from async_property import async_property
from keycloak import KeycloakGetError from keycloak import KeycloakGetError
from cpl.auth.keycloak import KeycloakAdmin from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProviderABC
_logger = AuthLogger(__name__)
class AuthUser(DbModelABC): class AuthUser(DbModelABC):
def __init__( def __init__(
@@ -36,13 +38,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username") return keycloak_admin.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger) _logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@property @property
@@ -51,13 +52,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS" return "ANONYMOUS"
try: try:
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin) keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email") return keycloak_admin.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e: except KeycloakGetError as e:
return "UNKNOWN" return "UNKNOWN"
except Exception as e: except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger) _logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN" return "UNKNOWN"
@async_property @async_property

View File

@@ -4,16 +4,19 @@ from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class AuthUserDao(DbModelDaoABC[AuthUser]): class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") self.attribute(AuthUser.keycloak_id, str, aliases=["keycloakId"])
async def get_users(): async def get_users():
return [(x.id, x.username, x.email) for x in await self.get_all()] return [(x.id, x.username, x.email) for x in await self.get_all()]
@@ -40,9 +43,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value) p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""
SELECT COUNT(*) as count SELECT COUNT(*)
FROM {TableManager.get("role_users")} ru FROM permission.role_users ru
JOIN {TableManager.get("role_permissions")} rp ON ru.roleId = rp.roleId JOIN permission.role_permissions rp ON ru.roleId = rp.roleId
WHERE ru.userId = {user_id} WHERE ru.userId = {user_id}
AND rp.permissionId = {p.id} AND rp.permissionId = {p.id}
AND ru.deleted = FALSE AND ru.deleted = FALSE
@@ -58,9 +61,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""
SELECT p.* SELECT p.*
FROM {TableManager.get("permissions")} p FROM permission.permissions p
JOIN {TableManager.get("role_permissions")} rp ON p.id = rp.permissionId JOIN permission.role_permissions rp ON p.id = rp.permissionId
JOIN {TableManager.get("role_users")} ru ON rp.roleId = ru.roleId JOIN permission.role_users ru ON rp.roleId = ru.roleId
WHERE ru.userId = {user_id} WHERE ru.userId = {user_id}
AND rp.deleted = FALSE AND rp.deleted = FALSE
AND ru.deleted = FALSE; AND ru.deleted = FALSE;

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]): class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions")) DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
self.attribute(ApiKeyPermission.api_key_id, int) self.attribute(ApiKeyPermission.api_key_id, int)
self.attribute(ApiKeyPermission.permission_id, int) self.attribute(ApiKeyPermission.permission_id, int)

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._permission.permission import Permission from cpl.auth.schema._permission.permission import Permission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionDao(DbModelDaoABC[Permission]): class PermissionDao(DbModelDaoABC[Permission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions")) DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
self.attribute(Permission.name, str) self.attribute(Permission.name, str)
self.attribute(Permission.description, Optional[str]) self.attribute(Permission.description, Optional[str])

View File

@@ -1,11 +1,14 @@
from cpl.auth.schema._permission.role import Role from cpl.auth.schema._permission.role import Role
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleDao(DbModelDaoABC[Role]): class RoleDao(DbModelDaoABC[Role]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Role, TableManager.get("roles")) DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
self.attribute(Role.name, str) self.attribute(Role.name, str)
self.attribute(Role.description, str) self.attribute(Role.description, str)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_permission import RolePermission from cpl.auth.schema._permission.role_permission import RolePermission
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RolePermissionDao(DbModelDaoABC[RolePermission]): class RolePermissionDao(DbModelDaoABC[RolePermission]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions")) DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
self.attribute(RolePermission.role_id, int) self.attribute(RolePermission.role_id, int)
self.attribute(RolePermission.permission_id, int) self.attribute(RolePermission.permission_id, int)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_user import RoleUser from cpl.auth.schema._permission.role_user import RoleUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleUserDao(DbModelDaoABC[RoleUser]): class RoleUserDao(DbModelDaoABC[RoleUser]):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users")) DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
self.attribute(RoleUser.role_id, int) self.attribute(RoleUser.role_id, int)
self.attribute(RoleUser.user_id, int) self.attribute(RoleUser.user_id, int)

View File

@@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
CREATE TABLE IF NOT EXISTS administration_auth_users_history CREATE TABLE IF NOT EXISTS administration_auth_users_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
keycloakId CHAR(36) NOT NULL, keycloakId CHAR(36) NOT NULL,
-- for history -- for history
deleted BOOL NOT NULL, deleted BOOL NOT NULL,

View File

@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
CREATE TABLE IF NOT EXISTS administration_api_keys_history CREATE TABLE IF NOT EXISTS administration_api_keys_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
identifier VARCHAR(255) NOT NULL, identifier VARCHAR(255) NOT NULL,
keyString VARCHAR(255) NOT NULL, keyString VARCHAR(255) NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
CREATE TABLE IF NOT EXISTS permission_permissions_history CREATE TABLE IF NOT EXISTS permission_permissions_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL,
description TEXT NULL, description TEXT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
@@ -57,7 +57,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
CREATE TABLE IF NOT EXISTS permission_roles_history CREATE TABLE IF NOT EXISTS permission_roles_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL,
description TEXT NULL, description TEXT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
@@ -103,7 +103,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
CREATE TABLE IF NOT EXISTS permission_role_permissions_history CREATE TABLE IF NOT EXISTS permission_role_permissions_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
RoleId INT NOT NULL, RoleId INT NOT NULL,
permissionId INT NOT NULL, permissionId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
@@ -149,7 +149,7 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
RoleId INT NOT NULL, RoleId INT NOT NULL,
UserId INT NOT NULL, UserId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,

View File

@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
( (
id INT NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY,
apiKeyId INT NOT NULL, apiKeyId INT NOT NULL,
permissionId INT NOT NULL, permissionId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,

View File

@@ -1,4 +1,4 @@
cpl-core cpl-core
cpl-dependency cpl-dependency
cpl-database cpl-database
python-keycloak==5.8.1 python-keycloak-5.8.1

View File

@@ -1,23 +0,0 @@
from abc import abstractmethod, ABC
from typing import Generic
from cpl.core.typing import T
class RegistryABC(ABC, Generic[T]):
@abstractmethod
def __init__(self):
self._items: dict[str, T] = {}
@abstractmethod
def extend(self, items: list[T]) -> None: ...
@abstractmethod
def add(self, item: T) -> None: ...
@abstractmethod
def get(self, key: str) -> T | None: ...
@abstractmethod
def all(self) -> list[T]: ...

View File

@@ -130,7 +130,7 @@ class Configuration:
key_name = key.__name__ if inspect.isclass(key) else key key_name = key.__name__ if inspect.isclass(key) else key
result = cls._config.get(key_name, default) result = cls._config.get(key_name, default)
if isclass(key) and issubclass(key, ConfigurationModelABC) and result == default: if issubclass(key, ConfigurationModelABC) and result == default:
result = key() result = key()
cls.set(key, result) cls.set(key, result)

View File

@@ -68,7 +68,7 @@ class ConfigurationModelABC(ABC):
value = cast(Environment.get(env_field, str), cast_type) value = cast(Environment.get(env_field, str), cast_type)
if value is None and required: if value is None and required:
raise ValueError(f"{type(self).__name__}.{field} is required") raise ValueError(f"{field} is required")
elif value is None: elif value is None:
self._options[field] = default self._options[field] = default
return return

View File

@@ -1,18 +1,17 @@
from contextvars import ContextVar from contextvars import ContextVar
from typing import Optional from typing import Optional
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) _user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
_logger = AuthLogger(__name__)
def set_user(user: Optional[AuthUser]):
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.core.log.logger_abc import LoggerABC
logger = ServiceProviderABC.get_global_service(LoggerABC) def set_user(user_id: Optional[AuthUser]):
logger.trace("Setting user context", user.id) _logger.trace("Setting user context", user_id)
_user_context.set(user) _user_context.set(user_id)
def get_user() -> Optional[AuthUser]: def get_user() -> Optional[AuthUser]:

View File

@@ -1,15 +0,0 @@
import traceback
from cpl.core.console import Console
def dependency_error(package_name: str, e: ImportError) -> None:
Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
tb = traceback.format_exc()
if not tb.startswith("NoneType: None"):
Console.write_line("->", tb)
elif e is not None:
Console.write_line("->", str(e))
exit(1)

View File

@@ -2,4 +2,3 @@ from .logger import Logger
from .logger_abc import LoggerABC from .logger_abc import LoggerABC
from .log_level import LogLevel from .log_level import LogLevel
from .log_settings import LogSettings from .log_settings import LogSettings
from .structured_logger import StructuredLogger

View File

@@ -1,111 +0,0 @@
import asyncio
import importlib.util
import json
import traceback
from datetime import datetime
from starlette.requests import Request
from cpl.core.log.log_level import LogLevel
from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages
class StructuredLogger(Logger):
def __init__(self, source: Source, file_prefix: str = None):
Logger.__init__(self, source, file_prefix)
@property
def log_file(self):
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
def _log(self, level: LogLevel, *messages: Messages):
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = self._format_message(level.value, timestamp, *messages)
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
self._write_log_to_file(level, structured_message)
self._write_to_console(level, formatted_message)
except Exception as e:
print(f"Error while logging: {e} -> {traceback.format_exc()}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self._source,
"messages": messages,
}
self._enrich_message_with_request(structured_message)
self._enrich_message_with_user(structured_message)
return json.dumps(structured_message, ensure_ascii=False)
@staticmethod
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
scope = dict(request.scope)
def convert(value):
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, (list, tuple)):
return [convert(v) for v in value]
if isinstance(value, dict):
return {str(k): convert(v) for k, v in value.items()}
if not isinstance(value, (str, int, float, bool, type(None))):
return str(value)
return value
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
if not include_headers and "headers" in serializable_scope:
serializable_scope["headers"] = "<omitted>"
return serializable_scope
def _enrich_message_with_request(self, message: dict):
if importlib.util.find_spec("cpl.api") is None:
return
from cpl.api.middleware.request import get_request
from starlette.requests import Request
request = get_request()
if request is None:
return
message["request"] = {
"url": str(request.url),
"method": request.method,
"scope": self._scope_to_json(request),
}
if isinstance(request, Request) and request.scope == "http":
request: Request = request # fix typing for IDEs
message["request"]["data"] = asyncio.create_task(request.body())
@staticmethod
def _enrich_message_with_user(message: dict):
if importlib.util.find_spec("cpl-auth") is None:
return
from cpl.core.ctx import get_user
user = get_user()
if user is None:
return
from cpl.dependency.service_provider_abc import ServiceProviderABC
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = {
"id": str(user.id),
"username": kc_user.get("username"),
"email": kc_user.get("email"),
}

View File

@@ -1,100 +0,0 @@
import inspect
from typing import Type
from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages
from cpl.dependency.service_provider_abc import ServiceProviderABC
class WrappedLogger(LoggerABC):
def __init__(self, file_prefix: str):
LoggerABC.__init__(self)
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
self._source = None
self._file_prefix = file_prefix
self._set_logger()
@ServiceProviderABC.inject
def _set_logger(self, services: ServiceProviderABC):
from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProviderABC")
self._logger = t_logger(self._source, self._file_prefix)
def set_level(self, level: LogLevel):
self._logger.set_level(level)
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
return self._logger._format_message(level, timestamp, *messages)
@staticmethod
def _get_source() -> str | None:
stack = inspect.stack()
if len(stack) <= 1:
return None
from cpl.dependency import ServiceCollection
ignore_classes = [
ServiceProviderABC,
ServiceProviderABC.__subclasses__(),
ServiceCollection,
WrappedLogger,
WrappedLogger.__subclasses__(),
]
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
for i, frame_info in enumerate(stack[1:]):
module = inspect.getmodule(frame_info.frame)
if module is None:
continue
if module.__name__ in ignore_classes or module in ignore_classes:
continue
if module in ignore_modules or module.__name__ in ignore_modules:
continue
if module.__name__ != __name__:
return module.__name__
return None
def _set_source(self):
self._source = self._get_source()
self._set_logger()
def header(self, string: str):
self._set_source()
self._logger.header(string)
def trace(self, *messages: Messages):
self._set_source()
self._logger.trace(*messages)
def debug(self, *messages: Messages):
self._set_source()
self._logger.debug(*messages)
def info(self, *messages: Messages):
self._set_source()
self._logger.info(*messages)
def warning(self, *messages: Messages):
self._set_source()
self._logger.warning(*messages)
def error(self, messages: str, e: Exception = None):
self._set_source()
self._logger.error(messages, e)
def fatal(self, messages: str, e: Exception = None):
self._set_source()
self._logger.fatal(messages, e)

View File

@@ -14,4 +14,3 @@ UuidId = str | UUID
SerialId = int SerialId = int
Id = UuidId | SerialId Id = UuidId | SerialId
TNumber = int | float | complex

View File

@@ -1,57 +0,0 @@
import time
import tracemalloc
from typing import List, Callable
from cpl.core.console import Console
class Benchmark:
@staticmethod
def all(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
mems: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_time = sum(times) / len(times)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
@staticmethod
def time(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
avg_time = sum(times) / len(times)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
@staticmethod
def memory(label: str, func: Callable, iterations: int = 5):
mems: List[float] = []
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")

View File

@@ -1,100 +0,0 @@
import threading
import time
from typing import Generic
from cpl.core.typing import T
class Cache(Generic[T]):
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
self._store = {}
self._default_ttl = default_ttl
self._lock = threading.Lock()
self._cleanup_interval = cleanup_interval
self._stop_event = threading.Event()
self._type = t
# Start background cleanup thread
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
self._thread.start()
def set(self, key: str, value: T, ttl: int = None) -> None:
"""Store a value in the cache with optional TTL override."""
expire_at = None
ttl = ttl if ttl is not None else self._default_ttl
if ttl is not None:
expire_at = time.time() + ttl
with self._lock:
self._store[key] = (value, expire_at)
def get(self, key: str) -> T | None:
"""Retrieve a value from the cache if not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return None
value, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return None
del self._store[key]
return None
return value
def get_all(self) -> list[T]:
"""Retrieve all non-expired values from the cache."""
now = time.time()
with self._lock:
valid_items = []
expired_keys = []
for k, (v, exp) in self._store.items():
if exp and exp < now:
expired_keys.append(k)
else:
valid_items.append(v)
for k in expired_keys:
del self._store[k]
return valid_items
def has(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return False
_, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return False
del self._store[key]
return False
return True
def delete(self, key: str) -> None:
"""Remove an item from the cache."""
with self._lock:
self._store.pop(key, None)
def clear(self) -> None:
"""Clear the entire cache."""
with self._lock:
self._store.clear()
def _auto_cleanup(self):
"""Background thread to clean expired items."""
while not self._stop_event.is_set():
self.cleanup()
self._stop_event.wait(self._cleanup_interval)
def cleanup(self) -> None:
"""Remove expired items immediately."""
now = time.time()
with self._lock:
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
for k in expired_keys:
del self._store[k]
def stop(self):
"""Stop the background cleanup thread."""
self._stop_event.set()
self._thread.join()

View File

@@ -1,48 +0,0 @@
from typing import Any
class Number:
@staticmethod
def is_number(value: Any) -> bool:
"""Check if the value is a number (int or float)."""
return isinstance(value, (int, float, complex))
@staticmethod
def to_number(value: Any) -> int | float | complex:
"""
Convert a given value into int, float, or complex.
Raises ValueError if conversion is not possible.
"""
if isinstance(value, (int, float, complex)):
return value
if isinstance(value, str):
value = value.strip()
for caster in (int, float, complex):
try:
return caster(value)
except ValueError:
continue
raise ValueError(f"Cannot convert string '{value}' to number.")
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
pass
try:
return float(value)
except Exception:
pass
try:
return complex(value)
except Exception:
pass
raise ValueError(f"Cannot convert type {type(value)} to number.")

View File

@@ -2,4 +2,5 @@ art==6.5
colorama==0.4.6 colorama==0.4.6
tabulate==0.9.0 tabulate==0.9.0
termcolor==3.1.0 termcolor==3.1.0
mysql-connector-python==9.4.0
pynput==1.8.1 pynput==1.8.1

View File

@@ -1,4 +1,3 @@
import os
from typing import Type from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC from cpl.application.abc import ApplicationABC as _ApplicationABC
@@ -8,19 +7,13 @@ from . import postgres as _postgres
from .table_manager import TableManager from .table_manager import TableManager
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC: def _with_migrations(self: _ApplicationABC, *paths: list[str]) -> _ApplicationABC:
from cpl.application.host import Host from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService) migration_service = self._services.get_service(MigrationService)
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")) migration_service.with_directory("./scripts")
if isinstance(paths, str):
paths = [paths]
for path in paths:
migration_service.with_directory(path)
Host.run(migration_service.migrate) Host.run(migration_service.migrate)
return self return self

View File

@@ -9,7 +9,7 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT from cpl.database.const import DATETIME_FORMAT
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
@@ -18,12 +18,16 @@ from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSor
class DataAccessObjectABC(ABC, Generic[T_DBM]): class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
from cpl.dependency.service_provider_abc import ServiceProviderABC from cpl.dependency.service_provider_abc import ServiceProviderABC
self._db = ServiceProviderABC.get_global_service(DBContextABC) self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = ServiceProviderABC.get_global_service(DBLogger) self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._logger = DBLogger(source)
self._model_type = model_type self._model_type = model_type
self._table_name = table_name self._table_name = table_name
@@ -152,16 +156,13 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
:param dict result: Result from the database :param dict result: Result from the database
:return: :return:
""" """
value_map: dict[str, Any] = {} value_map: dict[str, T] = {}
db_names = self.__db_names.items()
for db_name, value in result.items(): for db_name, value in result.items():
# Find the attribute name corresponding to the db_name # Find the attribute name corresponding to the db_name
attr_name = next((k for k, v in db_names if v == db_name), None) attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
if not attr_name: if attr_name:
continue value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
return self._model_type(**value_map) return self._model_type(**value_map)
@@ -483,7 +484,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
builder.with_temp_table(self._external_fields[temp]) builder.with_temp_table(self._external_fields[temp])
if for_count: if for_count:
builder.with_attribute("COUNT(*) as count", ignore_table_name=True) builder.with_attribute("COUNT(*)", ignore_table_name=True)
else: else:
builder.with_attribute("*") builder.with_attribute("*")

View File

@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]): class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
@abstractmethod @abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str): def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, model_type, table_name) DataAccessObjectABC.__init__(self, source, model_type, table_name)
self.attribute(DbModelABC.id, int, ignore=True) self.attribute(DbModelABC.id, int, ignore=True)
self.attribute(DbModelABC.deleted, bool) self.attribute(DbModelABC.deleted, bool)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class DBLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "db")

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class DBLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "db")

View File

@@ -4,17 +4,18 @@ from typing import Any, List, Dict, Tuple, Union
from mysql.connector import Error as MySQLError, PoolError from mysql.connector import Error as MySQLError, PoolError
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model.database_settings import DatabaseSettings from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_pool import MySQLPool from cpl.database.mysql.mysql_pool import MySQLPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC): class DBContext(DBContextABC):
def __init__(self, logger: DBLogger): def __init__(self):
DBContextABC.__init__(self) DBContextABC.__init__(self)
self._logger = logger
self._pool: MySQLPool = None self._pool: MySQLPool = None
self._fails = 0 self._fails = 0
@@ -22,62 +23,62 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings): def connect(self, database_settings: DatabaseSettings):
try: try:
self._logger.debug("Connecting to database") _logger.debug("Connecting to database")
self._pool = MySQLPool( self._pool = MySQLPool(
database_settings, database_settings,
) )
self._logger.info("Connected to database") _logger.info("Connected to database")
except Exception as e: except Exception as e:
self._logger.fatal("Connecting to database failed", e) _logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> List[List]: async def execute(self, statement: str, args=None, multi=True) -> List[List]:
self._logger.trace(f"execute {statement} with args: {args}") _logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi) return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> List[Dict]: async def select_map(self, statement: str, args=None) -> List[Dict]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select_map(statement, args) return await self._pool.select_map(statement, args)
except (MySQLError, PoolError) as e: except (MySQLError, PoolError) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select_map(statement, args) return await self.select_map(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]: async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select(statement, args) return await self._pool.select(statement, args)
except (MySQLError, PoolError) as e: except (MySQLError, PoolError) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select(statement, args) return await self.select(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e

View File

@@ -1,92 +1,105 @@
from typing import Optional, Any from typing import Optional, Any
import sqlparse import sqlparse
from mysql.connector.aio import MySQLConnectionPool import aiomysql
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class MySQLPool: class MySQLPool:
"""
Create a pool when connecting to MySQL, which will decrease the time spent in
requesting connection, creating connection, and closing connection.
"""
def __init__(self, database_settings: DatabaseSettings): def __init__(self, database_settings: DatabaseSettings):
self._dbconfig = { self._db_settings = database_settings
"host": database_settings.host, self.pool: Optional[aiomysql.Pool] = None
"port": database_settings.port,
"user": database_settings.user,
"password": database_settings.password,
"database": database_settings.database,
"ssl_disabled": True,
}
self._pool: Optional[MySQLConnectionPool] = None
async def _get_pool(self): async def _get_pool(self):
if self._pool is None: if self.pool is None or self.pool._closed:
self._pool = MySQLConnectionPool(
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
)
await self._pool.initialize_pool()
con = await self._pool.get_connection()
try: try:
async with await con.cursor() as cursor: self.pool = await aiomysql.create_pool(
await cursor.execute("SELECT 1") host=self._db_settings.host,
await cursor.fetchall() port=self._db_settings.port,
user=self._db_settings.user,
password=self._db_settings.password,
db=self._db_settings.database,
minsize=1,
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
)
except Exception as e: except Exception as e:
logger = ServiceProviderABC.get_global_service(DBLogger) _logger.fatal("Failed to connect to the database", e)
logger.fatal(f"Error connecting to the database: {e}") raise
finally: return self.pool
await con.close()
return self._pool
@staticmethod @staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True): async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
result = []
if multi: if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries: for q in queries:
if q.strip() == "": if q.strip() == "":
continue continue
await cursor.execute(q, args) await cursor.execute(q, args)
if cursor.description is not None:
result = await cursor.fetchall()
else: else:
await cursor.execute(query, args) await cursor.execute(query, args)
if cursor.description is not None:
result = await cursor.fetchall()
return result
async def execute(self, query: str, args=None, multi=True) -> list[list]: async def execute(self, query: str, args=None, multi=True) -> list[list]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool() pool = await self._get_pool()
con = await pool.get_connection() async with pool.acquire() as con:
try: async with con.cursor() as cursor:
async with await con.cursor() as cursor: await self._exec_sql(cursor, query, args, multi)
result = await self._exec_sql(cursor, query, args, multi)
await con.commit() await con.commit()
return result
finally: if cursor.description is not None: # Query returns rows
await con.close() res = await cursor.fetchall()
if res is None:
return []
return [list(row) for row in res]
else:
return []
async def select(self, query: str, args=None, multi=True) -> list[str]: async def select(self, query: str, args=None, multi=True) -> list[str]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool() pool = await self._get_pool()
con = await pool.get_connection() async with pool.acquire() as con:
try: async with con.cursor() as cursor:
async with await con.cursor() as cursor: await self._exec_sql(cursor, query, args, multi)
res = await self._exec_sql(cursor, query, args, multi) res = await cursor.fetchall()
return list(res) return list(res)
finally:
await con.close()
async def select_map(self, query: str, args=None, multi=True) -> list[dict]: async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool() pool = await self._get_pool()
con = await pool.get_connection() async with pool.acquire() as con:
try: async with con.cursor(aiomysql.DictCursor) as cursor:
async with await con.cursor(dictionary=True) as cursor: await self._exec_sql(cursor, query, args, multi)
res = await self._exec_sql(cursor, query, args, multi) res = await cursor.fetchall()
return list(res) return list(res)
finally:
await con.close()

View File

@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.database_settings import DatabaseSettings
from cpl.database.model import DatabaseSettings from cpl.database.db_logger import DBLogger
from cpl.database.postgres.postgres_pool import PostgresPool from cpl.database.postgres.postgres_pool import PostgresPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC): class DBContext(DBContextABC):
def __init__(self, logger: DBLogger): def __init__(self):
DBContextABC.__init__(self) DBContextABC.__init__(self)
self._logger = logger
self._pool: PostgresPool = None self._pool: PostgresPool = None
self._fails = 0 self._fails = 0
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings): def connect(self, database_settings: DatabaseSettings):
try: try:
self._logger.debug("Connecting to database") _logger.debug("Connecting to database")
self._pool = PostgresPool( self._pool = PostgresPool(
database_settings, database_settings,
Environment.get("DB_POOL_SIZE", int, 1), Environment.get("DB_POOL_SIZE", int, 1),
) )
self._logger.info("Connected to database") _logger.info("Connected to database")
except Exception as e: except Exception as e:
self._logger.fatal("Connecting to database failed", e) _logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> list[list]: async def execute(self, statement: str, args=None, multi=True) -> list[list]:
self._logger.trace(f"execute {statement} with args: {args}") _logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi) return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> list[dict]: async def select_map(self, statement: str, args=None) -> list[dict]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select_map(statement, args) return await self._pool.select_map(statement, args)
except (OperationalError, PoolTimeout) as e: except (OperationalError, PoolTimeout) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select_map(statement, args) return await self.select_map(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]: async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
self._logger.trace(f"select {statement} with args: {args}") _logger.trace(f"select {statement} with args: {args}")
try: try:
return await self._pool.select(statement, args) return await self._pool.select(statement, args)
except (OperationalError, PoolTimeout) as e: except (OperationalError, PoolTimeout) as e:
if self._fails >= 3: if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4() uid = uuid.uuid4()
raise Exception( raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}" f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
) )
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1 self._fails += 1
try: try:
self._logger.debug("Retry select") _logger.debug("Retry select")
return await self.select(statement, args) return await self.select(statement, args)
except Exception as e: except Exception as e:
pass pass
return [] return []
except Exception as e: except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e) _logger.error(f"Database error caused by `{statement}`", e)
raise e raise e

View File

@@ -5,9 +5,10 @@ from psycopg import sql
from psycopg_pool import AsyncConnectionPool, PoolTimeout from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class PostgresPool: class PostgresPool:
@@ -24,24 +25,21 @@ class PostgresPool:
f"password={database_settings.password} " f"password={database_settings.password} "
f"dbname={database_settings.database}" f"dbname={database_settings.database}"
) )
self._pool: Optional[AsyncConnectionPool] = None
self.pool: Optional[AsyncConnectionPool] = None
async def _get_pool(self): async def _get_pool(self):
if self._pool is None: pool = AsyncConnectionPool(
pool = AsyncConnectionPool( conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) )
) await pool.open()
await pool.open() try:
try: async with pool.connection() as con:
async with pool.connection() as con: await pool.check_connection(con)
await pool.check_connection(con) except PoolTimeout as e:
except PoolTimeout as e: await pool.close()
await pool.close() _logger.fatal(f"Failed to connect to the database", e)
logger = ServiceProviderABC.get_global_service(DBLogger) return pool
logger.fatal(f"Failed to connect to the database", e)
self._pool = pool
return self._pool
@staticmethod @staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True): async def _exec_sql(cursor: Any, query: str, args=None, multi=True):

View File

@@ -1,11 +1,14 @@
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.db_logger import DBLogger
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
_logger = DBLogger(__name__)
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]): class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
def __init__(self): def __init__(self):
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations")) DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId") self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")

View File

@@ -2,17 +2,18 @@ import glob
import os import os
from cpl.database.abc import DBContextABC from cpl.database.abc import DBContextABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.database.model import Migration from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
_logger = DBLogger(__name__)
class MigrationService: class MigrationService:
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao): def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
self._logger = logger
self._db = db self._db = db
self._executedMigrationDao = executedMigrationDao self._executedMigrationDao = executedMigrationDao
@@ -95,13 +96,13 @@ class MigrationService:
if migration_from_db is not None: if migration_from_db is not None:
continue continue
self._logger.debug(f"Running upgrade migration: {migration.name}") _logger.debug(f"Running upgrade migration: {migration.name}")
await self._db.execute(migration.script, multi=True) await self._db.execute(migration.script, multi=True)
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True) await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e: except Exception as e:
self._logger.fatal( _logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}", f"Migration failed: {migration.name}\n{active_statement}",
e, e,
) )

View File

@@ -1,16 +1,18 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger from cpl.database.db_logger import DBLogger
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class SeederService: class SeederService:
def __init__(self, provider: ServiceProviderABC): def __init__(self, provider: ServiceProviderABC):
self._provider = provider self._provider = provider
self._logger = provider.get_service(DBLogger)
async def seed(self): async def seed(self):
seeders = self._provider.get_services(DataSeederABC) seeders = self._provider.get_services(DataSeederABC)
self._logger.debug(f"Found {len(seeders)} seeders") _logger.debug(f"Found {len(seeders)} seeders")
for seeder in seeders: for seeder in seeders:
await seeder.seed() await seeder.seed()

View File

@@ -33,7 +33,7 @@ class TableManager:
}, },
"role_users": { "role_users": {
ServerTypes.POSTGRES: "permission.role_users", ServerTypes.POSTGRES: "permission.role_users",
ServerTypes.MYSQL: "permission_role_auth_users", ServerTypes.MYSQL: "permission_role_users",
}, },
} }

View File

@@ -1,8 +1,8 @@
from typing import Union, Type, Callable, Self from typing import Union, Type, Callable
from cpl.core.log.logger import Logger
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache
from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
@@ -15,17 +15,12 @@ class ServiceCollection:
_modules: dict[str, Callable] = {} _modules: dict[str, Callable] = {}
@classmethod @classmethod
def with_module(cls, func: Callable, name: str = None) -> type[Self]: def with_module(cls, func: Callable, name: str = None):
cls._modules[func.__name__ if name is None else name] = func cls._modules[func.__name__ if name is None else name] = func
return cls return cls
def __init__(self): def __init__(self):
self._service_descriptors: list[ServiceDescriptor] = [] self._service_descriptors: list[ServiceDescriptor] = []
self._loaded_modules: set[str] = set()
@property
def loaded_modules(self) -> set[str]:
return self._loaded_modules
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None): def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
found = False found = False
@@ -50,15 +45,15 @@ class ServiceCollection:
return self return self
def add_singleton(self, service_type: T, service: Service = None) -> Self: def add_singleton(self, service_type: T, service: Service = None):
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service) self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service)
return self return self
def add_scoped(self, service_type: T, service: Service = None) -> Self: def add_scoped(self, service_type: T, service: Service = None):
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service) self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service)
return self return self
def add_transient(self, service_type: T, service: Service = None) -> Self: def add_transient(self, service_type: T, service: Service = None):
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service) self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self return self
@@ -67,7 +62,7 @@ class ServiceCollection:
ServiceProviderABC.set_global_provider(sp) ServiceProviderABC.set_global_provider(sp)
return sp return sp
def add_module(self, module: str | object) -> Self: def add_module(self, module: str | object):
if not isinstance(module, str): if not isinstance(module, str):
module = module.__name__ module = module.__name__
@@ -75,29 +70,7 @@ class ServiceCollection:
raise ValueError(f"Module {module} not found") raise ValueError(f"Module {module} not found")
self._modules[module](self) self._modules[module](self)
if module not in self._loaded_modules:
self._loaded_modules.add(module)
return self
def add_logging(self) -> Self:
from cpl.core.log.logger import Logger
from cpl.core.log.wrapped_logger import WrappedLogger
def add_logging(self):
self.add_transient(LoggerABC, Logger) self.add_transient(LoggerABC, Logger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_structured_logging(self) -> Self:
from cpl.core.log.structured_logger import StructuredLogger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, StructuredLogger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_cache(self, t: Type[T]):
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
return self return self

View File

@@ -1,7 +1,7 @@
import copy import copy
import typing import typing
from inspect import signature, Parameter, Signature from inspect import signature, Parameter, Signature
from typing import Optional, Type from typing import Optional
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
@@ -37,23 +37,8 @@ class ServiceProvider(ServiceProviderABC):
self._scope: Optional[ScopeABC] = None self._scope: Optional[ScopeABC] = None
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type
type_args = list(typing.get_args(service_type))
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
descriptor_type_args = list(typing.get_args(descriptor.base_type))
if descriptor_base_type == origin_type and len(descriptor_type_args) == 0 and len(type_args) == 0:
return descriptor
if descriptor_base_type != origin_type or len(descriptor_type_args) != len(type_args):
continue
if descriptor_base_type == origin_type and type_args != descriptor_type_args:
continue
if descriptor.service_type == origin_type or issubclass(descriptor.base_type, origin_type):
return descriptor return descriptor
return None return None
@@ -92,7 +77,7 @@ class ServiceProvider(ServiceProviderABC):
return implementations return implementations
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]: def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]:
params = [] params = []
for param in sig.parameters.items(): for param in sig.parameters.items():
parameter = param[1] parameter = param[1]
@@ -173,12 +158,6 @@ class ServiceProvider(ServiceProviderABC):
return implementation return implementation
def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]:
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
return descriptor.service_type
return None
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]: def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
implementations = [] implementations = []
@@ -188,10 +167,3 @@ class ServiceProvider(ServiceProviderABC):
implementations.extend(self._get_services(service_type)) implementations.extend(self._get_services(service_type))
return implementations return implementations
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
types = []
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
types.append(descriptor.service_type)
return types

View File

@@ -24,19 +24,19 @@ class ServiceProviderABC(ABC):
return cls._provider return cls._provider
@classmethod @classmethod
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]: def get_global_service(cls, instance_type: T, *args, **kwargs) -> Optional[R]:
if cls._provider is None: if cls._provider is None:
return None return None
return cls._provider.get_service(instance_type, *args, **kwargs) return cls._provider.get_service(instance_type, *args, **kwargs)
@classmethod @classmethod
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]: def get_global_services(cls, instance_type: T, *args, **kwargs) -> list[Optional[R]]:
if cls._provider is None: if cls._provider is None:
return [] return []
return cls._provider.get_services(instance_type, *args, **kwargs) return cls._provider.get_services(instance_type, *args, **kwargs)
@abstractmethod @abstractmethod
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ... def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]: ...
@abstractmethod @abstractmethod
def _build_service(self, service_type: type, *args, **kwargs) -> object: def _build_service(self, service_type: type, *args, **kwargs) -> object:
@@ -85,20 +85,6 @@ class ServiceProviderABC(ABC):
Object of type Optional[:class:`cpl.core.type.T`] Object of type Optional[:class:`cpl.core.type.T`]
""" """
@abstractmethod
def get_service_type(self, instance_type: Type[T]) -> Optional[Type[T]]:
r"""Returns the registered service type for loggers
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`type`]
"""
@abstractmethod @abstractmethod
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]: def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type r"""Returns instance of given type
@@ -113,20 +99,6 @@ class ServiceProviderABC(ABC):
Object of type list[Optional[:class:`cpl.core.type.T`] Object of type list[Optional[:class:`cpl.core.type.T`]
""" """
@abstractmethod
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
r"""Returns all registered service types
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[:class:`type`]
"""
@classmethod @classmethod
def inject(cls, f=None): def inject(cls, f=None):
r"""Decorator to allow injection into static and class methods r"""Decorator to allow injection into static and class methods
@@ -142,24 +114,14 @@ class ServiceProviderABC(ABC):
if f is None: if f is None:
return functools.partial(cls.inject) return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f) @functools.wraps(f)
def inner(*args, **kwargs): async def inner(*args, **kwargs):
if cls._provider is None: if cls._provider is None:
raise Exception(f"{cls.__name__} not build!") raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None] injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
if iscoroutinefunction(f):
return await f(*args, *injection, **kwargs)
return f(*args, *injection, **kwargs) return f(*args, *injection, **kwargs)
return inner return inner

View File

@@ -3,7 +3,7 @@ from .abc.email_client_abc import EMailClientABC
from .email_client import EMailClient from .email_client import EMailClient
from .email_client_settings import EMailClientSettings from .email_client_settings import EMailClientSettings
from .email_model import EMail from .email_model import EMail
from .logger import MailLogger from .mail_logger import MailLogger
def add_mail(collection: _ServiceCollection): def add_mail(collection: _ServiceCollection):

View File

@@ -5,7 +5,7 @@ from typing import Optional
from cpl.mail.abc.email_client_abc import EMailClientABC from cpl.mail.abc.email_client_abc import EMailClientABC
from cpl.mail.email_client_settings import EMailClientSettings from cpl.mail.email_client_settings import EMailClientSettings
from cpl.mail.email_model import EMail from cpl.mail.email_model import EMail
from cpl.mail.logger import MailLogger from cpl.mail.mail_logger import MailLogger
class EMailClient(EMailClientABC): class EMailClient(EMailClientABC):

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class MailLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "mail")

View File

@@ -0,0 +1,8 @@
from cpl.core.log.logger import Logger
from cpl.core.typing import Source
class MailLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "mail")

View File

@@ -1,7 +1 @@
from .array import Array
from .enumerable import Enumerable
from .immutable_list import ImmutableList
from .immutable_set import ImmutableSet
from .list import List
from .ordered_enumerable import OrderedEnumerable
from .set import Set

View File

@@ -0,0 +1,2 @@
def is_number(t: type) -> bool:
return issubclass(t, int) or issubclass(t, float) or issubclass(t, complex)

View File

@@ -1,44 +0,0 @@
from typing import Generic, Iterable, Optional
from cpl.core.typing import T
from cpl.query.list import List
from cpl.query.enumerable import Enumerable
class Array(Generic[T], List[T]):
def __init__(self, length: int, source: Optional[Iterable[T]] = None):
List.__init__(self, source)
self._length = length
@property
def length(self) -> int:
return len(self._source)
def add(self, item: T) -> None:
if self._length == self.length:
raise IndexError("Array is full")
self._source.append(item)
def extend(self, items: Iterable[T]) -> None:
if self._length == self.length:
raise IndexError("Array is full")
self._source.extend(items)
def insert(self, index: int, item: T) -> None:
if index < 0 or index > self.length:
raise IndexError("Index out of range")
self._source.insert(index, item)
def remove(self, item: T) -> None:
self._source.remove(item)
def pop(self, index: int = -1) -> T:
return self._source.pop(index)
def clear(self) -> None:
self._source.clear()
def to_enumerable(self) -> "Enumerable[T]":
from cpl.query.enumerable import Enumerable
return Enumerable(self._source)

View File

@@ -0,0 +1,5 @@
from .default_lambda import default_lambda
from .ordered_queryable import OrderedQueryable
from .sequence import Sequence
from .ordered_queryable_abc import OrderedQueryableABC
from .queryable_abc import QueryableABC

View File

@@ -0,0 +1,2 @@
def default_lambda(x: object):
return x

View File

@@ -0,0 +1,34 @@
from collections.abc import Callable
from cpl.query.base.ordered_queryable_abc import OrderedQueryableABC
from cpl.query.exceptions import ArgumentNoneException, ExceptionArgument
class OrderedQueryable(OrderedQueryableABC):
r"""Implementation of :class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`"""
def __init__(self, _t: type, _values: OrderedQueryableABC = None, _func: Callable = None):
OrderedQueryableABC.__init__(self, _t, _values, _func)
def then_by(self, _func: Callable) -> OrderedQueryableABC:
if self is None:
raise ArgumentNoneException(ExceptionArgument.list)
if _func is None:
raise ArgumentNoneException(ExceptionArgument.func)
self._funcs.append(_func)
return OrderedQueryable(self.type, sorted(self, key=lambda *args: [f(*args) for f in self._funcs]), _func)
def then_by_descending(self, _func: Callable) -> OrderedQueryableABC:
if self is None:
raise ArgumentNoneException(ExceptionArgument.list)
if _func is None:
raise ArgumentNoneException(ExceptionArgument.func)
self._funcs.append(_func)
return OrderedQueryable(
self.type, sorted(self, key=lambda *args: [f(*args) for f in self._funcs], reverse=True), _func
)

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable
from typing import Iterable
from cpl.query.base.queryable_abc import QueryableABC
class OrderedQueryableABC(QueryableABC):
@abstractmethod
def __init__(self, _t: type, _values: Iterable = None, _func: Callable = None):
QueryableABC.__init__(self, _t, _values)
self._funcs: list[Callable] = []
if _func is not None:
self._funcs.append(_func)
@abstractmethod
def then_by(self, func: Callable) -> OrderedQueryableABC:
r"""Sorts OrderedList in ascending order by function
Parameter:
func: :class:`Callable`
Returns:
list of :class:`cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""
@abstractmethod
def then_by_descending(self, func: Callable) -> OrderedQueryableABC:
r"""Sorts OrderedList in descending order by function
Parameter:
func: :class:`Callable`
Returns:
list of :class:`cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""

View File

@@ -0,0 +1,569 @@
from __future__ import annotations
from typing import Optional, Callable, Union, Iterable, Any
from cpl.query._helper import is_number
from cpl.query.base import default_lambda
from cpl.query.base.sequence import Sequence
from cpl.query.exceptions import (
InvalidTypeException,
ArgumentNoneException,
ExceptionArgument,
IndexOutOfRangeException,
)
class QueryableABC(Sequence):
def __init__(self, t: type, values: Iterable = None):
Sequence.__init__(self, t, values)
def all(self, _func: Callable = None) -> bool:
r"""Checks if every element of list equals result found by function
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
bool
"""
if _func is None:
_func = default_lambda
return self.count(_func) == self.count()
def any(self, _func: Callable = None) -> bool:
r"""Checks if list contains result found by function
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
bool
"""
if _func is None:
_func = default_lambda
return self.where(_func).count() > 0
def average(self, _func: Callable = None) -> Union[int, float, complex]:
r"""Returns average value of list
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
Union[int, float, complex]
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
return self.sum(_func) / self.count()
def contains(self, _value: object) -> bool:
r"""Checks if list contains value given by function
Parameter
---------
value: :class:`object`
value
Returns
-------
bool
"""
if _value is None:
raise ArgumentNoneException(ExceptionArgument.value)
return self.where(lambda x: x == _value).count() > 0
def count(self, _func: Callable = None) -> int:
r"""Returns length of list or count of found elements
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
int
"""
if _func is None:
return self.__len__()
return self.where(_func).count()
def distinct(self, _func: Callable = None) -> QueryableABC:
r"""Returns list without redundancies
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _func is None:
_func = default_lambda
result = []
known_values = []
for element in self:
value = _func(element)
if value in known_values:
continue
known_values.append(value)
result.append(element)
return type(self)(self._type, result)
def element_at(self, _index: int) -> any:
r"""Returns element at given index
Parameter
---------
_index: :class:`int`
index
Returns
-------
Value at _index: any
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
if _index < 0 or _index >= self.count():
raise IndexOutOfRangeException
result = self._values[_index]
if result is None:
raise IndexOutOfRangeException
return result
def element_at_or_default(self, _index: int) -> Optional[any]:
r"""Returns element at given index or None
Parameter
---------
_index: :class:`int`
index
Returns
-------
Value at _index: Optional[any]
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
try:
return self._values[_index]
except IndexError:
return None
def first(self) -> any:
r"""Returns first element
Returns
-------
First element of list: any
"""
if self.count() == 0:
raise IndexOutOfRangeException()
return self._values[0]
def first_or_default(self) -> any:
r"""Returns first element or None
Returns
-------
First element of list: Optional[any]
"""
if self.count() == 0:
return None
return self._values[0]
def for_each(self, _func: Callable = None):
r"""Runs given function for each element of list
Parameter
---------
func: :class: `Callable`
function to call
"""
if _func is not None:
for element in self:
_func(element)
return self
def group_by(self, _func: Callable = None) -> QueryableABC:
r"""Groups by func
Returns
-------
Grouped list[list[any]]: any
"""
if _func is None:
_func = default_lambda
groups = {}
for v in self:
value = _func(v)
if v not in groups:
groups[value] = []
groups[value].append(v)
v = []
for g in groups.values():
v.append(type(self)(object, g))
x = type(self)(type(self), v)
return x
def last(self) -> any:
r"""Returns last element
Returns
-------
Last element of list: any
"""
if self.count() == 0:
raise IndexOutOfRangeException()
return self._values[self.count() - 1]
def last_or_default(self) -> any:
r"""Returns last element or None
Returns
-------
Last element of list: Optional[any]
"""
if self.count() == 0:
return None
return self._values[self.count() - 1]
def max(self, _func: Callable = None) -> object:
r"""Returns the highest value
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
object
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
if _func is None:
_func = default_lambda
return _func(max(self, key=_func))
def median(self, _func=None) -> Union[int, float]:
r"""Return the median value of data elements
Returns
-------
Union[int, float]
"""
if _func is None:
_func = default_lambda
result = self.order_by(_func).select(_func).to_list()
length = len(result)
i = int(length / 2)
return result[i] if length % 2 == 1 else (float(result[i - 1]) + float(result[i])) / float(2)
def min(self, _func: Callable = None) -> object:
r"""Returns the lowest value
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
object
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
if _func is None:
_func = default_lambda
return _func(min(self, key=_func))
def order_by(self, _func: Callable = None) -> "OrderedQueryableABC":
r"""Sorts elements by function in ascending order
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""
if _func is None:
_func = default_lambda
from cpl.query.base.ordered_queryable import OrderedQueryable
return OrderedQueryable(self.type, sorted(self, key=_func), _func)
def order_by_descending(self, _func: Callable = None) -> "OrderedQueryableABC":
r"""Sorts elements by function in descending order
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""
if _func is None:
_func = default_lambda
from cpl.query.base.ordered_queryable import OrderedQueryable
return OrderedQueryable(self.type, sorted(self, key=_func, reverse=True), _func)
def reverse(self) -> QueryableABC:
r"""Reverses list
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
return type(self)(self._type, reversed(self._values))
def select(self, _func: Callable) -> QueryableABC:
r"""Formats each element of list to a given format
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _func is None:
_func = default_lambda
_l = [_func(_o) for _o in self]
_t = type(_l[0]) if len(_l) > 0 else Any
return type(self)(_t, _l)
def select_many(self, _func: Callable) -> QueryableABC:
r"""Flattens resulting lists to one
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
# The line below is pain. I don't understand anything of it...
# written on 09.11.2022 by Sven Heidemann
return type(self)(object, [_a for _o in self for _a in _func(_o)])
def single(self) -> any:
r"""Returns one single element of list
Returns
-------
Found value: any
Raises
------
ArgumentNoneException: when argument is None
Exception: when argument is None or found more than one element
"""
if self.count() > 1:
raise Exception("Found more than one element")
elif self.count() == 0:
raise Exception("Found no element")
return self._values[0]
def single_or_default(self) -> Optional[any]:
r"""Returns one single element of list
Returns
-------
Found value: Optional[any]
"""
if self.count() > 1:
raise Exception("Index out of range")
elif self.count() == 0:
return None
return self._values[0]
def skip(self, _index: int) -> QueryableABC:
r"""Skips all elements from index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
return type(self)(self.type, self._values[_index:])
def skip_last(self, _index: int) -> QueryableABC:
r"""Skips all elements after index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
index = self.count() - _index
return type(self)(self._type, self._values[:index])
def sum(self, _func: Callable = None) -> Union[int, float, complex]:
r"""Sum of all values
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
Union[int, float, complex]
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
if _func is None:
_func = default_lambda
result = 0
for x in self:
result += _func(x)
return result
def split(self, _func: Callable) -> QueryableABC:
r"""Splits the list by given function
Parameter
---------
func: :class:`Callable`
seperator
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
groups = []
group = []
for x in self:
v = _func(x)
if x == v:
groups.append(group)
group = []
group.append(x)
groups.append(group)
query_groups = []
for g in groups:
if len(g) == 0:
continue
query_groups.append(type(self)(self._type, g))
return type(self)(self._type, query_groups)
def take(self, _index: int) -> QueryableABC:
r"""Takes all elements from index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
return type(self)(self._type, self._values[:_index])
def take_last(self, _index: int) -> QueryableABC:
r"""Takes all elements after index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
index = self.count() - _index
if index >= self.count() or index < 0:
raise IndexOutOfRangeException()
return type(self)(self._type, self._values[index:])
def where(self, _func: Callable = None) -> QueryableABC:
r"""Select element by function
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _func is None:
raise ArgumentNoneException(ExceptionArgument.func)
if _func is None:
_func = default_lambda
return type(self)(self.type, filter(_func, self))

View File

@@ -0,0 +1,96 @@
from abc import abstractmethod, ABC
from typing import Iterable
class Sequence(ABC):
@abstractmethod
def __init__(self, t: type, values: Iterable = None):
assert t is not None
assert isinstance(t, type) or t == any
assert values is None or isinstance(values, Iterable)
if values is None:
values = []
self._values = list(values)
if t is None:
t = object
self._type = t
def __iter__(self):
return iter(self._values)
def __next__(self):
return next(iter(self._values))
def __len__(self):
return self.to_list().__len__()
@classmethod
def __class_getitem__(cls, _t: type) -> type:
return _t
def __repr__(self):
return f"<{type(self).__name__} {self.to_list().__repr__()}>"
@property
def type(self) -> type:
return self._type
def _check_type(self, __object: any):
if self._type == any:
return
if (
self._type is not None
and type(__object) != self._type
and not isinstance(type(__object), self._type)
and not issubclass(type(__object), self._type)
):
raise Exception(f"Unexpected type: {type(__object)}\nExpected type: {self._type}")
def to_list(self) -> list:
r"""Converts :class: `cpl.query.base.sequence_abc.SequenceABC` to :class: `list`
Returns:
:class: `list`
"""
return [x for x in self._values]
def copy(self) -> "Sequence":
r"""Creates a copy of sequence
Returns:
Sequence
"""
return type(self)(self._type, self.to_list())
@classmethod
def empty(cls) -> "Sequence":
r"""Returns an empty sequence
Returns:
Sequence object that contains no elements
"""
return cls(object, [])
def index_of(self, _object: object) -> int:
r"""Returns the index of given element
Returns:
Index of object
Raises:
IndexError if object not in sequence
"""
for i, o in enumerate(self):
if o == _object:
return i
raise IndexError
@classmethod
def range(cls, start: int, length: int) -> "Sequence":
return cls(int, range(start, length))

Some files were not shown because too many files have changed in this diff Show More