Compare commits
15 Commits
2025.09.19
...
2025.09.23
| Author | SHA1 | Date | |
|---|---|---|---|
| e296c0992b | |||
| 6639946346 | |||
| b9ac11e15f | |||
| 77d821bb6e | |||
| 86ad953ff1 | |||
| d6b7eb9b30 | |||
| 12b7c62b69 | |||
| 7fc70747bb | |||
| 6de4f3c03a | |||
| ea3055527c | |||
| 7b37748ca6 | |||
| 073b35f71a | |||
| eceff6128b | |||
| 17dfb245bf | |||
| 4f698269b5 |
@@ -16,7 +16,7 @@ jobs:
|
|||||||
uses: ./.gitea/workflows/package.yaml
|
uses: ./.gitea/workflows/package.yaml
|
||||||
needs: [ prepare, application, auth, core, dependency ]
|
needs: [ prepare, application, auth, core, dependency ]
|
||||||
with:
|
with:
|
||||||
working_directory: src/cpl-application
|
working_directory: src/cpl-api
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
application:
|
application:
|
||||||
|
|||||||
26
.gitea/workflows/test_before_merge.yaml
Normal file
26
.gitea/workflows/test_before_merge.yaml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
name: Test before pr merge
|
||||||
|
run-name: Test before pr merge
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- edited
|
||||||
|
- reopened
|
||||||
|
- synchronize
|
||||||
|
- ready_for_review
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test-lint:
|
||||||
|
runs-on: [ runner ]
|
||||||
|
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
|
||||||
|
steps:
|
||||||
|
- name: Clone Repository
|
||||||
|
uses: https://github.com/actions/checkout@v3
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CI_ACCESS_TOKEN }}
|
||||||
|
|
||||||
|
- name: Installing black
|
||||||
|
run: python3.12 -m pip install black
|
||||||
|
|
||||||
|
- name: Checking black
|
||||||
|
run: python3.12 -m black src --check
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -139,3 +139,6 @@ PythonImportHelper-v2-Completion.json
|
|||||||
|
|
||||||
# cpl unittest stuff
|
# cpl unittest stuff
|
||||||
unittests/test_*_playground
|
unittests/test_*_playground
|
||||||
|
|
||||||
|
# cpl logs
|
||||||
|
**/logs/*.jsonl
|
||||||
|
|||||||
61
install.sh
Normal file
61
install.sh
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Find and combine requirements from src/cpl-*/requirements.txt,
|
||||||
|
# filtering out lines whose *package name* starts with "cpl-".
|
||||||
|
# Works with pinned versions, extras, markers, editable installs, and VCS refs.
|
||||||
|
|
||||||
|
shopt -s nullglob
|
||||||
|
|
||||||
|
req_files=(src/cpl-*/requirements.txt)
|
||||||
|
if ((${#req_files[@]} == 0)); then
|
||||||
|
echo "No requirements files found at src/cpl-*/requirements.txt" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
tmp_combined="$(mktemp)"
|
||||||
|
trap 'rm -f "$tmp_combined"' EXIT
|
||||||
|
|
||||||
|
# Concatenate, trim comments/whitespace, filter out cpl-* packages, dedupe.
|
||||||
|
# We keep non-package options/flags/constraints as-is.
|
||||||
|
awk '
|
||||||
|
function trim(s){ sub(/^[[:space:]]+/,"",s); sub(/[[:space:]]+$/,"",s); return s }
|
||||||
|
|
||||||
|
{
|
||||||
|
line=$0
|
||||||
|
# drop full-line comments and strip inline comments
|
||||||
|
if (line ~ /^[[:space:]]*#/) next
|
||||||
|
sub(/#[^!].*$/,"",line) # strip trailing comment (simple heuristic)
|
||||||
|
line=trim(line)
|
||||||
|
if (line == "") next
|
||||||
|
|
||||||
|
# Determine the package *name* even for "-e", extras, pins, markers, or VCS "@"
|
||||||
|
e = line
|
||||||
|
sub(/^-e[[:space:]]+/,"",e) # remove editable prefix
|
||||||
|
# Tokenize up to the first of these separators: space, [ < > = ! ~ ; @
|
||||||
|
token = e
|
||||||
|
sub(/\[.*/,"",token) # remove extras quickly
|
||||||
|
n = split(token, a, /[<>=!~;@[:space:]]/)
|
||||||
|
name = tolower(a[1])
|
||||||
|
|
||||||
|
# If the first token (name) starts with "cpl-", skip this requirement
|
||||||
|
if (name ~ /^cpl-/) next
|
||||||
|
|
||||||
|
print line
|
||||||
|
}
|
||||||
|
' "${req_files[@]}" | sort -u > "$tmp_combined"
|
||||||
|
|
||||||
|
if ! [ -s "$tmp_combined" ]; then
|
||||||
|
echo "Nothing to install after filtering out cpl-* packages." >&2
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Installing dependencies (excluding cpl-*) from:"
|
||||||
|
printf ' - %s\n' "${req_files[@]}"
|
||||||
|
echo
|
||||||
|
echo "Final set to install:"
|
||||||
|
cat "$tmp_combined"
|
||||||
|
echo
|
||||||
|
|
||||||
|
# Use python -m pip for reliability; change to python3 if needed.
|
||||||
|
python -m pip install -r "$tmp_combined"
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||||
|
|
||||||
|
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
|
||||||
|
from .logger import APILogger
|
||||||
|
from .settings import ApiSettings
|
||||||
|
|
||||||
|
|
||||||
|
def add_api(collection: _ServiceCollection):
|
||||||
|
try:
|
||||||
|
from cpl.database import mysql
|
||||||
|
|
||||||
|
collection.add_module(mysql)
|
||||||
|
except ImportError as e:
|
||||||
|
from cpl.core.errors import dependency_error
|
||||||
|
|
||||||
|
dependency_error("cpl-database", e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cpl import auth
|
||||||
|
from cpl.auth import permission
|
||||||
|
|
||||||
|
collection.add_module(auth)
|
||||||
|
collection.add_module(permission)
|
||||||
|
except ImportError as e:
|
||||||
|
from cpl.core.errors import dependency_error
|
||||||
|
|
||||||
|
dependency_error("cpl-auth", e)
|
||||||
|
|
||||||
|
from cpl.api.registry.policy import PolicyRegistry
|
||||||
|
from cpl.api.registry.route import RouteRegistry
|
||||||
|
|
||||||
|
collection.add_singleton(PolicyRegistry)
|
||||||
|
collection.add_singleton(RouteRegistry)
|
||||||
|
|
||||||
|
|
||||||
|
_ServiceCollection.with_module(add_api, __name__)
|
||||||
|
|||||||
1
src/cpl-api/cpl/api/abc/__init__.py
Normal file
1
src/cpl-api/cpl/api/abc/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .asgi_middleware_abc import ASGIMiddleware
|
||||||
15
src/cpl-api/cpl/api/abc/asgi_middleware_abc.py
Normal file
15
src/cpl-api/cpl/api/abc/asgi_middleware_abc.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
|
||||||
|
class ASGIMiddleware(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, app):
|
||||||
|
self._app = app
|
||||||
|
|
||||||
|
def _call_next(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
return self._app(scope, receive, send)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send): ...
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
from cpl.core.log.logger import Logger
|
|
||||||
|
|
||||||
|
|
||||||
class APILogger(Logger):
|
|
||||||
|
|
||||||
def __init__(self, source: str):
|
|
||||||
Logger.__init__(self, source, "api")
|
|
||||||
1
src/cpl-api/cpl/api/application/__init__.py
Normal file
1
src/cpl-api/cpl/api/application/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .web_app import WebApp
|
||||||
249
src/cpl-api/cpl/api/application/web_app.py
Normal file
249
src/cpl-api/cpl/api/application/web_app.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Mapping, Any, Callable, Self, Union
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import ExceptionHandler
|
||||||
|
|
||||||
|
from cpl import api, auth
|
||||||
|
from cpl.api.error import APIError
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
|
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||||
|
from cpl.api.middleware.authorization import AuthorizationMiddleware
|
||||||
|
from cpl.api.middleware.logging import LoggingMiddleware
|
||||||
|
from cpl.api.middleware.request import RequestMiddleware
|
||||||
|
from cpl.api.model.api_route import ApiRoute
|
||||||
|
from cpl.api.model.policy import Policy
|
||||||
|
from cpl.api.model.validation_match import ValidationMatch
|
||||||
|
from cpl.api.registry.policy import PolicyRegistry
|
||||||
|
from cpl.api.registry.route import RouteRegistry
|
||||||
|
from cpl.api.router import Router
|
||||||
|
from cpl.api.settings import ApiSettings
|
||||||
|
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||||
|
from cpl.application.abc.application_abc import ApplicationABC
|
||||||
|
from cpl.core.configuration import Configuration
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
|
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||||
|
|
||||||
|
|
||||||
|
class WebApp(ApplicationABC):
|
||||||
|
def __init__(self, services: ServiceProviderABC):
|
||||||
|
super().__init__(services, [auth, api])
|
||||||
|
self._app: Starlette | None = None
|
||||||
|
|
||||||
|
self._logger = services.get_service(APILogger)
|
||||||
|
|
||||||
|
self._api_settings = Configuration.get(ApiSettings)
|
||||||
|
self._policies = services.get_service(PolicyRegistry)
|
||||||
|
self._routes = services.get_service(RouteRegistry)
|
||||||
|
|
||||||
|
self._middleware: list[Middleware] = [
|
||||||
|
Middleware(RequestMiddleware),
|
||||||
|
Middleware(LoggingMiddleware),
|
||||||
|
]
|
||||||
|
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||||
|
Exception: self._handle_exception,
|
||||||
|
APIError: self._handle_exception,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _handle_exception(self, request: Request, exc: Exception):
|
||||||
|
if isinstance(exc, APIError):
|
||||||
|
self._logger.error(exc)
|
||||||
|
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||||
|
|
||||||
|
if hasattr(request.state, "request_id"):
|
||||||
|
self._logger.error(f"Request {request.state.request_id}", exc)
|
||||||
|
else:
|
||||||
|
self._logger.error("Request unknown", exc)
|
||||||
|
|
||||||
|
return JSONResponse({"error": str(exc)}, status_code=500)
|
||||||
|
|
||||||
|
def _get_allowed_origins(self):
|
||||||
|
origins = self._api_settings.allowed_origins
|
||||||
|
|
||||||
|
if origins is None or origins == "":
|
||||||
|
self._logger.warning("No allowed origins specified, allowing all origins")
|
||||||
|
return ["*"]
|
||||||
|
|
||||||
|
self._logger.debug(f"Allowed origins: {origins}")
|
||||||
|
return origins.split(",")
|
||||||
|
|
||||||
|
def with_database(self) -> Self:
|
||||||
|
self.with_migrations()
|
||||||
|
self.with_seeders()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_app(self, app: Starlette) -> Self:
|
||||||
|
assert app is not None, "app must not be None"
|
||||||
|
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||||
|
self._app = app
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _check_for_app(self):
|
||||||
|
if self._app is not None:
|
||||||
|
raise ValueError("App is already set, cannot add routes or middleware")
|
||||||
|
|
||||||
|
def with_routes_directory(self, directory: str) -> Self:
|
||||||
|
self._check_for_app()
|
||||||
|
assert directory is not None, "directory must not be None"
|
||||||
|
|
||||||
|
base = directory.replace("/", ".").replace("\\", ".")
|
||||||
|
|
||||||
|
for filename in os.listdir(directory):
|
||||||
|
if not filename.endswith(".py") or filename == "__init__.py":
|
||||||
|
continue
|
||||||
|
|
||||||
|
__import__(f"{base}.{filename[:-3]}")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_routes(
|
||||||
|
self,
|
||||||
|
routes: list[ApiRoute],
|
||||||
|
method: HTTPMethods,
|
||||||
|
authentication: bool = False,
|
||||||
|
roles: list[str | Enum] = None,
|
||||||
|
permissions: list[str | Enum] = None,
|
||||||
|
policies: list[str] = None,
|
||||||
|
match: ValidationMatch = None,
|
||||||
|
) -> Self:
|
||||||
|
self._check_for_app()
|
||||||
|
assert self._routes is not None, "routes must not be None"
|
||||||
|
assert all(isinstance(route, ApiRoute) for route in routes), "all routes must be of type ApiRoute"
|
||||||
|
for route in routes:
|
||||||
|
self.with_route(
|
||||||
|
route.path,
|
||||||
|
route.fn,
|
||||||
|
method,
|
||||||
|
authentication,
|
||||||
|
roles,
|
||||||
|
permissions,
|
||||||
|
policies,
|
||||||
|
match,
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_route(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
fn: Callable[[Request], Any],
|
||||||
|
method: HTTPMethods,
|
||||||
|
authentication: bool = False,
|
||||||
|
roles: list[str | Enum] = None,
|
||||||
|
permissions: list[str | Enum] = None,
|
||||||
|
policies: list[str] = None,
|
||||||
|
match: ValidationMatch = None,
|
||||||
|
) -> Self:
|
||||||
|
self._check_for_app()
|
||||||
|
assert path is not None, "path must not be None"
|
||||||
|
assert fn is not None, "fn must not be None"
|
||||||
|
assert method in [
|
||||||
|
"GET",
|
||||||
|
"HEAD",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"OPTIONS",
|
||||||
|
], "method must be a valid HTTP method"
|
||||||
|
|
||||||
|
Router.route(path, method, registry=self._routes)(fn)
|
||||||
|
|
||||||
|
if authentication:
|
||||||
|
Router.authenticate()(fn)
|
||||||
|
|
||||||
|
if roles or permissions or policies:
|
||||||
|
Router.authorize(roles, permissions, policies, match)(fn)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
||||||
|
self._check_for_app()
|
||||||
|
|
||||||
|
if isinstance(middleware, Middleware):
|
||||||
|
self._middleware.append(middleware)
|
||||||
|
elif callable(middleware):
|
||||||
|
self._middleware.append(Middleware(middleware))
|
||||||
|
else:
|
||||||
|
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_authentication(self) -> Self:
|
||||||
|
self.with_middleware(AuthenticationMiddleware)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
|
||||||
|
if policies:
|
||||||
|
_policies = []
|
||||||
|
|
||||||
|
if not isinstance(policies, list):
|
||||||
|
policies = list(policies)
|
||||||
|
|
||||||
|
for i, policy in enumerate(policies):
|
||||||
|
if isinstance(policy, dict):
|
||||||
|
for name, resolver in policy.items():
|
||||||
|
if not isinstance(name, str):
|
||||||
|
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not callable(resolver):
|
||||||
|
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||||
|
continue
|
||||||
|
|
||||||
|
_policies.append(Policy(name, resolver))
|
||||||
|
continue
|
||||||
|
|
||||||
|
_policies.append(policy)
|
||||||
|
|
||||||
|
self._policies.extend(_policies)
|
||||||
|
|
||||||
|
self.with_middleware(AuthorizationMiddleware)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _validate_policies(self):
|
||||||
|
for rule in Router.get_authorization_rules():
|
||||||
|
for policy_name in rule["policies"]:
|
||||||
|
policy = self._policies.get(policy_name)
|
||||||
|
if not policy:
|
||||||
|
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||||
|
|
||||||
|
async def main(self):
|
||||||
|
self._logger.debug(f"Preparing API")
|
||||||
|
self._validate_policies()
|
||||||
|
|
||||||
|
if self._app is None:
|
||||||
|
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
||||||
|
|
||||||
|
app = Starlette(
|
||||||
|
routes=routes,
|
||||||
|
middleware=[
|
||||||
|
*self._middleware,
|
||||||
|
Middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=self._get_allowed_origins(),
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
exception_handlers=self._exception_handlers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
app = self._app
|
||||||
|
|
||||||
|
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||||
|
|
||||||
|
config = uvicorn.Config(
|
||||||
|
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
|
self._logger.info("Shutdown API")
|
||||||
@@ -1,9 +1,30 @@
|
|||||||
from http.client import HTTPException
|
from http.client import HTTPException
|
||||||
|
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
|
||||||
class APIError(HTTPException):
|
class APIError(HTTPException):
|
||||||
status_code = 500
|
status_code = 500
|
||||||
|
|
||||||
|
def __init__(self, message: str = ""):
|
||||||
|
super().__init__(self.status_code, message)
|
||||||
|
self._message = message
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error_message(self) -> str:
|
||||||
|
if self._message:
|
||||||
|
return f"{type(self).__name__}: {self._message}"
|
||||||
|
|
||||||
|
return f"{type(self).__name__}"
|
||||||
|
|
||||||
|
async def asgi_response(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
r = JSONResponse({"error": self.error_message}, status_code=self.status_code)
|
||||||
|
return await r(scope, receive, send)
|
||||||
|
|
||||||
|
def response(self):
|
||||||
|
return JSONResponse({"error": self.error_message}, status_code=self.status_code)
|
||||||
|
|
||||||
|
|
||||||
class Unauthorized(APIError):
|
class Unauthorized(APIError):
|
||||||
status_code = 401
|
status_code = 401
|
||||||
|
|||||||
7
src/cpl-api/cpl/api/logger.py
Normal file
7
src/cpl-api/cpl/api/logger.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||||
|
|
||||||
|
|
||||||
|
class APILogger(WrappedLogger):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
WrappedLogger.__init__(self, "api")
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from .authentication import AuthenticationMiddleware
|
||||||
|
from .authorization import AuthorizationMiddleware
|
||||||
|
from .logging import LoggingMiddleware
|
||||||
|
from .request import RequestMiddleware
|
||||||
|
|||||||
80
src/cpl-api/cpl/api/middleware/authentication.py
Normal file
80
src/cpl-api/cpl/api/middleware/authentication.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from keycloak import KeycloakAuthenticationError
|
||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.api.error import Unauthorized
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
|
from cpl.api.router import Router
|
||||||
|
from cpl.auth.keycloak import KeycloakClient
|
||||||
|
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||||
|
from cpl.core.ctx import set_user
|
||||||
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
|
@ServiceProviderABC.inject
|
||||||
|
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||||
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
self._keycloak = keycloak
|
||||||
|
self._user_dao = user_dao
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
request = get_request()
|
||||||
|
url = request.url.path
|
||||||
|
|
||||||
|
if url not in Router.get_auth_required_routes():
|
||||||
|
self._logger.trace(f"No authentication required for {url}")
|
||||||
|
return await self._app(scope, receive, send)
|
||||||
|
|
||||||
|
if not request.headers.get("Authorization"):
|
||||||
|
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||||
|
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", None)
|
||||||
|
if not auth_header or not auth_header.startswith("Bearer "):
|
||||||
|
return await Unauthorized("Invalid Authorization header").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
token = auth_header.split("Bearer ")[1]
|
||||||
|
if not await self._verify_login(token):
|
||||||
|
self._logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||||
|
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
# check user exists in db, if not create
|
||||||
|
keycloak_id = self._keycloak.get_user_id(token)
|
||||||
|
if keycloak_id is None:
|
||||||
|
return await Unauthorized("Failed to get user id from token").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
user = await self._get_or_crate_user(keycloak_id)
|
||||||
|
if user.deleted:
|
||||||
|
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||||
|
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
request.state.user = user
|
||||||
|
set_user(user)
|
||||||
|
|
||||||
|
return await self._call_next(scope, receive, send)
|
||||||
|
|
||||||
|
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
|
||||||
|
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||||
|
if existing is not None:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
user = AuthUser(0, keycloak_id)
|
||||||
|
uid = await self._user_dao.create(user)
|
||||||
|
return await self._user_dao.get_by_id(uid)
|
||||||
|
|
||||||
|
async def _verify_login(self, token: str) -> bool:
|
||||||
|
try:
|
||||||
|
token_info = self._keycloak.introspect(token)
|
||||||
|
return token_info.get("active", False)
|
||||||
|
except KeycloakAuthenticationError as e:
|
||||||
|
self._logger.debug(f"Keycloak authentication error: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Unexpected error during token verification: {e}")
|
||||||
|
return False
|
||||||
73
src/cpl-api/cpl/api/middleware/authorization.py
Normal file
73
src/cpl-api/cpl/api/middleware/authorization.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.api.error import Unauthorized, Forbidden
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
|
from cpl.api.model.validation_match import ValidationMatch
|
||||||
|
from cpl.api.registry.policy import PolicyRegistry
|
||||||
|
from cpl.api.router import Router
|
||||||
|
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||||
|
from cpl.core.ctx.user_context import get_user
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
|
@ServiceProviderABC.inject
|
||||||
|
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||||
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
self._policies = policies
|
||||||
|
self._user_dao = user_dao
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
request = get_request()
|
||||||
|
url = request.url.path
|
||||||
|
|
||||||
|
if url not in Router.get_authorization_rules_paths():
|
||||||
|
self._logger.trace(f"No authorization required for {url}")
|
||||||
|
return await self._app(scope, receive, send)
|
||||||
|
|
||||||
|
user = get_user()
|
||||||
|
if not user:
|
||||||
|
return await Unauthorized(f"Unknown user").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
roles = await user.roles
|
||||||
|
request.state.roles = roles
|
||||||
|
role_names = [r.name for r in roles]
|
||||||
|
|
||||||
|
perms = await user.permissions
|
||||||
|
request.state.permissions = perms
|
||||||
|
perm_names = [p.name for p in perms]
|
||||||
|
|
||||||
|
for rule in Router.get_authorization_rules():
|
||||||
|
match = rule["match"]
|
||||||
|
if rule["roles"]:
|
||||||
|
if match == ValidationMatch.all and not all(r in role_names for r in rule["roles"]):
|
||||||
|
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
|
||||||
|
if match == ValidationMatch.any and not any(r in role_names for r in rule["roles"]):
|
||||||
|
return await Forbidden(f"missing roles: {rule["roles"]}").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
if rule["permissions"]:
|
||||||
|
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
|
||||||
|
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||||
|
scope, receive, send
|
||||||
|
)
|
||||||
|
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
|
||||||
|
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||||
|
scope, receive, send
|
||||||
|
)
|
||||||
|
|
||||||
|
for policy_name in rule["policies"]:
|
||||||
|
policy = self._policies.get(policy_name)
|
||||||
|
if not policy:
|
||||||
|
self._logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not await policy.resolve(user):
|
||||||
|
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
return await self._call_next(scope, receive, send)
|
||||||
@@ -1,21 +1,46 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
_logger = APILogger(__name__)
|
from cpl.api.middleware.request import get_request
|
||||||
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
class LoggingMiddleware(ASGIMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
|
||||||
|
@ServiceProviderABC.inject
|
||||||
|
def __init__(self, app, logger: APILogger):
|
||||||
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self._call_next(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
request = get_request()
|
||||||
await self._log_request(request)
|
await self._log_request(request)
|
||||||
response = await call_next(request)
|
start_time = time.time()
|
||||||
await self._log_after_request(request, response)
|
|
||||||
|
|
||||||
return response
|
response_body = b""
|
||||||
|
status_code = 500
|
||||||
|
|
||||||
|
async def send_wrapper(message):
|
||||||
|
nonlocal response_body, status_code
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
status_code = message["status"]
|
||||||
|
if message["type"] == "http.response.body":
|
||||||
|
response_body += message.get("body", b"")
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
await self._call_next(scope, receive, send_wrapper)
|
||||||
|
|
||||||
|
duration = (time.time() - start_time) * 1000
|
||||||
|
await self._log_after_request(request, status_code, duration)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _filter_relevant_headers(headers: dict) -> dict:
|
def _filter_relevant_headers(headers: dict) -> dict:
|
||||||
@@ -30,10 +55,9 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
}
|
}
|
||||||
return {key: value for key, value in headers.items() if key in relevant_keys}
|
return {key: value for key, value in headers.items() if key in relevant_keys}
|
||||||
|
|
||||||
@classmethod
|
async def _log_request(self, request: Request):
|
||||||
async def _log_request(cls, request: Request):
|
self._logger.debug(
|
||||||
_logger.debug(
|
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||||
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from cpl.core.ctx.user_context import get_user
|
from cpl.core.ctx.user_context import get_user
|
||||||
@@ -41,7 +65,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
user = get_user()
|
user = get_user()
|
||||||
|
|
||||||
request_info = {
|
request_info = {
|
||||||
"headers": cls._filter_relevant_headers(dict(request.headers)),
|
"headers": self._filter_relevant_headers(dict(request.headers)),
|
||||||
"args": dict(request.query_params),
|
"args": dict(request.query_params),
|
||||||
"form-data": (
|
"form-data": (
|
||||||
await request.form()
|
await request.form()
|
||||||
@@ -55,11 +79,9 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
_logger.trace(f"Request {request.state.request_id}: {request_info}")
|
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||||
|
|
||||||
@staticmethod
|
async def _log_after_request(self, request: Request, status_code: int, duration: float):
|
||||||
async def _log_after_request(request: Request, response: Response):
|
self._logger.info(
|
||||||
duration = (time.time() - request.state.start_time) * 1000
|
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||||
_logger.info(
|
|
||||||
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,46 +3,54 @@ from contextvars import ContextVar
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.requests import Request
|
||||||
from starlette.websockets import WebSocket
|
from starlette.types import Scope, Receive, Send
|
||||||
|
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.typing import TRequest
|
from cpl.api.typing import TRequest
|
||||||
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||||
|
|
||||||
_logger = APILogger(__name__)
|
|
||||||
|
|
||||||
|
class RequestMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
class RequestMiddleware(BaseHTTPMiddleware):
|
@ServiceProviderABC.inject
|
||||||
_request_token = {}
|
def __init__(self, app, logger: APILogger):
|
||||||
_user_token = {}
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
@classmethod
|
self._logger = logger
|
||||||
async def set_request_data(cls, request: TRequest):
|
|
||||||
|
self._ctx_token = None
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
request = Request(scope, receive, send)
|
||||||
|
await self.set_request_data(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._app(scope, receive, send)
|
||||||
|
finally:
|
||||||
|
await self.clean_request_data()
|
||||||
|
|
||||||
|
async def set_request_data(self, request: TRequest):
|
||||||
request.state.request_id = uuid4()
|
request.state.request_id = uuid4()
|
||||||
request.state.start_time = time.time()
|
request.state.start_time = time.time()
|
||||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
self._logger.trace(f"Set new current request: {request.state.request_id}")
|
||||||
|
|
||||||
cls._request_token[request.state.request_id] = _request_context.set(request)
|
self._ctx_token = _request_context.set(request)
|
||||||
|
|
||||||
@classmethod
|
async def clean_request_data(self):
|
||||||
async def clean_request_data(cls):
|
|
||||||
request = get_request()
|
request = get_request()
|
||||||
if request is None:
|
if request is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if request.state.request_id in cls._request_token:
|
if self._ctx_token is None:
|
||||||
_request_context.reset(cls._request_token[request.state.request_id])
|
return
|
||||||
|
|
||||||
async def dispatch(self, request: TRequest, call_next):
|
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||||
await self.set_request_data(request)
|
_request_context.reset(self._ctx_token)
|
||||||
try:
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
finally:
|
|
||||||
await self.clean_request_data()
|
|
||||||
|
|
||||||
|
|
||||||
def get_request() -> Optional[Union[TRequest, WebSocket]]:
|
def get_request() -> Optional[TRequest]:
|
||||||
return _request_context.get()
|
return _request_context.get()
|
||||||
|
|||||||
3
src/cpl-api/cpl/api/model/__init__.py
Normal file
3
src/cpl-api/cpl/api/model/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .api_route import ApiRoute
|
||||||
|
from .policy import Policy
|
||||||
|
from .validation_match import ValidationMatch
|
||||||
43
src/cpl-api/cpl/api/model/api_route.py
Normal file
43
src/cpl-api/cpl/api/model/api_route.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from cpl.api.typing import HTTPMethods
|
||||||
|
|
||||||
|
|
||||||
|
class ApiRoute:
|
||||||
|
|
||||||
|
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
|
||||||
|
self._path = path
|
||||||
|
self._fn = fn
|
||||||
|
self._method = method
|
||||||
|
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._fn.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fn(self) -> Callable:
|
||||||
|
return self._fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self) -> str:
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def method(self) -> HTTPMethods:
|
||||||
|
return self._method
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kwargs(self) -> dict:
|
||||||
|
return self._kwargs
|
||||||
|
|
||||||
|
def to_starlette(self, wrap_endpoint: Callable = None) -> Route:
|
||||||
|
return Route(
|
||||||
|
self._path,
|
||||||
|
self._fn if not wrap_endpoint else wrap_endpoint(self._fn),
|
||||||
|
methods=[self._method],
|
||||||
|
**self._kwargs,
|
||||||
|
)
|
||||||
34
src/cpl-api/cpl/api/model/policy.py
Normal file
34
src/cpl-api/cpl/api/model/policy.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cpl.api.typing import PolicyResolver
|
||||||
|
from cpl.core.ctx import get_user
|
||||||
|
|
||||||
|
|
||||||
|
class Policy:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
resolver: PolicyResolver = None,
|
||||||
|
):
|
||||||
|
self._name = name
|
||||||
|
self._resolver: Optional[PolicyResolver] = resolver
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resolvers(self) -> PolicyResolver:
|
||||||
|
return self._resolver
|
||||||
|
|
||||||
|
async def resolve(self, *args, **kwargs) -> bool:
|
||||||
|
if not self._resolver:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if callable(self._resolver):
|
||||||
|
if iscoroutinefunction(self._resolver):
|
||||||
|
return await self._resolver(get_user())
|
||||||
|
|
||||||
|
return self._resolver(get_user())
|
||||||
|
return False
|
||||||
6
src/cpl-api/cpl/api/model/validation_match.py
Normal file
6
src/cpl-api/cpl/api/model/validation_match.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationMatch(Enum):
|
||||||
|
any = "any"
|
||||||
|
all = "all"
|
||||||
2
src/cpl-api/cpl/api/registry/__init__.py
Normal file
2
src/cpl-api/cpl/api/registry/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .policy import PolicyRegistry
|
||||||
|
from .route import RouteRegistry
|
||||||
28
src/cpl-api/cpl/api/registry/policy.py
Normal file
28
src/cpl-api/cpl/api/registry/policy.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cpl.api.model.policy import Policy
|
||||||
|
from cpl.core.abc.registry_abc import RegistryABC
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyRegistry(RegistryABC):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
RegistryABC.__init__(self)
|
||||||
|
|
||||||
|
def extend(self, items: list[Policy]):
|
||||||
|
for policy in items:
|
||||||
|
self.add(policy)
|
||||||
|
|
||||||
|
def add(self, item: Policy):
|
||||||
|
assert isinstance(item, Policy), "policy must be an instance of Policy"
|
||||||
|
|
||||||
|
if item.name in self._items:
|
||||||
|
raise ValueError(f"Policy {item.name} is already registered")
|
||||||
|
|
||||||
|
self._items[item.name] = item
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[Policy]:
|
||||||
|
return self._items.get(key)
|
||||||
|
|
||||||
|
def all(self) -> list[Policy]:
|
||||||
|
return list(self._items.values())
|
||||||
32
src/cpl-api/cpl/api/registry/route.py
Normal file
32
src/cpl-api/cpl/api/registry/route.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cpl.api.model.api_route import ApiRoute
|
||||||
|
from cpl.core.abc.registry_abc import RegistryABC
|
||||||
|
|
||||||
|
|
||||||
|
class RouteRegistry(RegistryABC):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
RegistryABC.__init__(self)
|
||||||
|
|
||||||
|
def extend(self, items: list[ApiRoute]):
|
||||||
|
for policy in items:
|
||||||
|
self.add(policy)
|
||||||
|
|
||||||
|
def add(self, item: ApiRoute):
|
||||||
|
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||||
|
|
||||||
|
if item.path in self._items:
|
||||||
|
raise ValueError(f"ApiRoute {item.path} is already registered")
|
||||||
|
|
||||||
|
self._items[item.path] = item
|
||||||
|
|
||||||
|
def set(self, item: ApiRoute):
|
||||||
|
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||||
|
self._items[item.path] = item
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[ApiRoute]:
|
||||||
|
return self._items.get(key)
|
||||||
|
|
||||||
|
def all(self) -> list[ApiRoute]:
|
||||||
|
return list(self._items.values())
|
||||||
@@ -1,41 +1,136 @@
|
|||||||
from starlette.routing import Route
|
from enum import Enum
|
||||||
|
|
||||||
|
from cpl.api.model.validation_match import ValidationMatch
|
||||||
|
from cpl.api.registry.route import RouteRegistry
|
||||||
|
from cpl.api.typing import HTTPMethods
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
_registered_routes: list[Route] = []
|
_auth_required: list[str] = []
|
||||||
|
_authorization_rules: dict[str, dict] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_routes(cls) -> list[Route]:
|
def get_auth_required_routes(cls) -> list[str]:
|
||||||
return cls._registered_routes
|
return cls._auth_required
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def route(cls, path=None, **kwargs):
|
def get_authorization_rules_paths(cls) -> list[str]:
|
||||||
|
return list(cls._authorization_rules.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_authorization_rules(cls) -> list[dict]:
|
||||||
|
return list(cls._authorization_rules.values())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def authenticate(cls):
|
||||||
|
"""
|
||||||
|
Decorator to mark a route as requiring authentication.
|
||||||
|
Usage:
|
||||||
|
@Route.authenticate()
|
||||||
|
@Route.get("/example")
|
||||||
|
async def example_endpoint(request: TRequest):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
def inner(fn):
|
def inner(fn):
|
||||||
cls._registered_routes.append(Route(path, fn, **kwargs))
|
route_path = getattr(fn, "_route_path", None)
|
||||||
|
if route_path and route_path not in cls._auth_required:
|
||||||
|
cls._auth_required.append(route_path)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def authorize(
|
||||||
|
cls,
|
||||||
|
roles: list[str | Enum] = None,
|
||||||
|
permissions: list[str | Enum] = None,
|
||||||
|
policies: list[str] = None,
|
||||||
|
match: ValidationMatch = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decorator to mark a route as requiring authorization.
|
||||||
|
Usage:
|
||||||
|
@Route.authorize()
|
||||||
|
@Route.get("/example")
|
||||||
|
async def example_endpoint(request: TRequest):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
assert roles is None or isinstance(roles, list), "roles must be a list of strings"
|
||||||
|
assert permissions is None or isinstance(permissions, list), "permissions must be a list of strings"
|
||||||
|
assert policies is None or isinstance(policies, list), "policies must be a list of strings"
|
||||||
|
assert match is None or isinstance(match, ValidationMatch), "match must be an instance of ValidationMatch"
|
||||||
|
|
||||||
|
if roles is not None:
|
||||||
|
for role in roles:
|
||||||
|
if isinstance(role, Enum):
|
||||||
|
roles[roles.index(role)] = role.value
|
||||||
|
|
||||||
|
if permissions is not None:
|
||||||
|
for perm in permissions:
|
||||||
|
if isinstance(perm, Enum):
|
||||||
|
permissions[permissions.index(perm)] = perm.value
|
||||||
|
|
||||||
|
def inner(fn):
|
||||||
|
path = getattr(fn, "_route_path", None)
|
||||||
|
if not path:
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if path in cls._authorization_rules:
|
||||||
|
raise ValueError(f"Route {path} is already registered for authorization")
|
||||||
|
|
||||||
|
cls._authorization_rules[path] = {
|
||||||
|
"roles": roles or [],
|
||||||
|
"permissions": permissions or [],
|
||||||
|
"policies": policies or [],
|
||||||
|
"match": match or ValidationMatch.all,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||||
|
from cpl.api.model.api_route import ApiRoute
|
||||||
|
|
||||||
|
if not registry:
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||||
|
else:
|
||||||
|
routes = registry
|
||||||
|
|
||||||
|
def inner(fn):
|
||||||
|
routes.add(ApiRoute(path, fn, method, **kwargs))
|
||||||
setattr(fn, "_route_path", path)
|
setattr(fn, "_route_path", path)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, path=None, **kwargs):
|
def get(cls, path: str, **kwargs):
|
||||||
return cls.route(path, methods=["GET"], **kwargs)
|
return cls.route(path, "GET", **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def post(cls, path=None, **kwargs):
|
def head(cls, path: str, **kwargs):
|
||||||
return cls.route(path, methods=["POST"], **kwargs)
|
return cls.route(path, "HEAD", **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def head(cls, path=None, **kwargs):
|
def post(cls, path: str, **kwargs):
|
||||||
return cls.route(path, methods=["HEAD"], **kwargs)
|
return cls.route(path, "POST", **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def put(cls, path=None, **kwargs):
|
def put(cls, path: str, **kwargs):
|
||||||
return cls.route(path, methods=["PUT"], **kwargs)
|
return cls.route(path, "PUT", **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def delete(cls, path=None, **kwargs):
|
def patch(cls, path: str, **kwargs):
|
||||||
return cls.route(path, methods=["DELETE"], **kwargs)
|
return cls.route(path, "PATCH", **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete(cls, path: str, **kwargs):
|
||||||
|
return cls.route(path, "DELETE", **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override(cls):
|
def override(cls):
|
||||||
@@ -48,13 +143,22 @@ class Router:
|
|||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from cpl.api.model.api_route import ApiRoute
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||||
|
|
||||||
def inner(fn):
|
def inner(fn):
|
||||||
route_path = getattr(fn, "_route_path", None)
|
path = getattr(fn, "_route_path", None)
|
||||||
|
if path is None:
|
||||||
|
raise ValueError("Cannot override a route that has not been registered yet")
|
||||||
|
|
||||||
routes = list(filter(lambda x: x.path == route_path, cls._registered_routes))
|
route = routes.get(path)
|
||||||
for route in routes[:-1]:
|
if route is None:
|
||||||
cls._registered_routes.remove(route)
|
raise ValueError(f"Cannot override a route that does not exist: {path}")
|
||||||
|
|
||||||
|
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
|
||||||
|
setattr(fn, "_route_path", path)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
@@ -1,13 +1,19 @@
|
|||||||
from typing import Union, Literal, Callable
|
from typing import Union, Literal, Callable, Type, Awaitable
|
||||||
from urllib.request import Request
|
from urllib.request import Request
|
||||||
|
|
||||||
from starlette.middleware import Middleware
|
from starlette.middleware import Middleware
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
|
from cpl.auth.schema import AuthUser
|
||||||
|
|
||||||
TRequest = Union[Request, WebSocket]
|
TRequest = Union[Request, WebSocket]
|
||||||
HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
||||||
PartialMiddleware = Union[
|
PartialMiddleware = Union[
|
||||||
|
ASGIMiddleware,
|
||||||
|
Type[ASGIMiddleware],
|
||||||
Middleware,
|
Middleware,
|
||||||
Callable[[ASGIApp], ASGIApp],
|
Callable[[ASGIApp], ASGIApp],
|
||||||
]
|
]
|
||||||
|
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Mapping, Any, Callable
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from starlette.applications import Starlette
|
|
||||||
from starlette.middleware import Middleware
|
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from starlette.routing import Route
|
|
||||||
from starlette.types import ExceptionHandler
|
|
||||||
|
|
||||||
from cpl.api.api_logger import APILogger
|
|
||||||
from cpl.api.api_settings import ApiSettings
|
|
||||||
from cpl.api.error import APIError
|
|
||||||
from cpl.api.middleware.logging import LoggingMiddleware
|
|
||||||
from cpl.api.middleware.request import RequestMiddleware
|
|
||||||
from cpl.api.router import Router
|
|
||||||
from cpl.api.typing import HTTPMethods, PartialMiddleware
|
|
||||||
from cpl.application.abc.application_abc import ApplicationABC
|
|
||||||
from cpl.core.configuration import Configuration
|
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
|
||||||
|
|
||||||
_logger = APILogger("API")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WebApp(ApplicationABC):
|
|
||||||
def __init__(self, services: ServiceProviderABC):
|
|
||||||
super().__init__(services)
|
|
||||||
self._app: Starlette | None = None
|
|
||||||
|
|
||||||
self._api_settings = Configuration.get(ApiSettings)
|
|
||||||
|
|
||||||
self._routes: list[Route] = []
|
|
||||||
self._middleware: list[Middleware] = [
|
|
||||||
Middleware(RequestMiddleware),
|
|
||||||
Middleware(LoggingMiddleware),
|
|
||||||
]
|
|
||||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {Exception: self.handle_exception}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def handle_exception(request: Request, exc: Exception):
|
|
||||||
if hasattr(request.state, "request_id"):
|
|
||||||
_logger.error(f"Request {request.state.request_id}", exc)
|
|
||||||
else:
|
|
||||||
_logger.error("Request unknown", exc)
|
|
||||||
|
|
||||||
if isinstance(exc, APIError):
|
|
||||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
|
||||||
|
|
||||||
return JSONResponse({"error": str(exc)}, status_code=500)
|
|
||||||
|
|
||||||
def _get_allowed_origins(self):
|
|
||||||
origins = self._api_settings.allowed_origins
|
|
||||||
|
|
||||||
if origins is None or origins == "":
|
|
||||||
_logger.warning("No allowed origins specified, allowing all origins")
|
|
||||||
return ["*"]
|
|
||||||
|
|
||||||
_logger.debug(f"Allowed origins: {origins}")
|
|
||||||
return origins.split(",")
|
|
||||||
|
|
||||||
def with_app(self, app: Starlette):
|
|
||||||
assert app is not None, "app must not be None"
|
|
||||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
|
||||||
self._app = app
|
|
||||||
return self
|
|
||||||
|
|
||||||
def _check_for_app(self):
|
|
||||||
if self._app is not None:
|
|
||||||
raise ValueError("App is already set, cannot add routes or middleware")
|
|
||||||
|
|
||||||
def with_routes_directory(self, directory: str) -> "WebApp":
|
|
||||||
self._check_for_app()
|
|
||||||
assert directory is not None, "directory must not be None"
|
|
||||||
|
|
||||||
base = directory.replace("/", ".").replace("\\", ".")
|
|
||||||
|
|
||||||
for filename in os.listdir(directory):
|
|
||||||
if not filename.endswith(".py") or filename == "__init__.py":
|
|
||||||
continue
|
|
||||||
|
|
||||||
__import__(f"{base}.{filename[:-3]}")
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_routes(self, routes: list[Route]) -> "WebApp":
|
|
||||||
self._check_for_app()
|
|
||||||
assert self._routes is not None, "routes must not be None"
|
|
||||||
assert all(isinstance(route, Route) for route in routes), "all routes must be of type starlette.routing.Route"
|
|
||||||
self._routes.extend(routes)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_route(self, path: str, fn: Callable[[Request], Any], method: HTTPMethods, **kwargs) -> "WebApp":
|
|
||||||
self._check_for_app()
|
|
||||||
assert path is not None, "path must not be None"
|
|
||||||
assert fn is not None, "fn must not be None"
|
|
||||||
assert method in ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], "method must be a valid HTTP method"
|
|
||||||
self._routes.append(Route(path, fn, methods=[method], **kwargs))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_middleware(self, middleware: PartialMiddleware) -> "WebApp":
|
|
||||||
self._check_for_app()
|
|
||||||
|
|
||||||
if isinstance(middleware, Middleware):
|
|
||||||
self._middleware.append(middleware)
|
|
||||||
|
|
||||||
elif callable(middleware):
|
|
||||||
self._middleware.append(Middleware(middleware))
|
|
||||||
else:
|
|
||||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
|
||||||
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def main(self):
|
|
||||||
_logger.debug(f"Preparing API")
|
|
||||||
if self._app is None:
|
|
||||||
routes = [
|
|
||||||
Route(
|
|
||||||
path=route.path,
|
|
||||||
endpoint=self._services.inject(route.endpoint),
|
|
||||||
methods=route.methods,
|
|
||||||
name=route.name,
|
|
||||||
)
|
|
||||||
for route in self._routes + Router.get_routes()
|
|
||||||
]
|
|
||||||
|
|
||||||
app = Starlette(
|
|
||||||
routes=routes,
|
|
||||||
middleware=[
|
|
||||||
*self._middleware,
|
|
||||||
Middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=self._get_allowed_origins(),
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
exception_handlers=self._exception_handlers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
app = self._app
|
|
||||||
|
|
||||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
|
||||||
uvicorn.run(
|
|
||||||
app,
|
|
||||||
host=self._api_settings.host,
|
|
||||||
port=self._api_settings.port,
|
|
||||||
log_config=None,
|
|
||||||
)
|
|
||||||
_logger.info("Shutdown API")
|
|
||||||
@@ -3,16 +3,16 @@ requires = ["setuptools>=70.1.0", "wheel>=0.43.0"]
|
|||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "cpl-application"
|
name = "cpl-api"
|
||||||
version = "2024.7.0"
|
version = "2024.7.0"
|
||||||
description = "CPL application"
|
description = "CPL api"
|
||||||
readme ="CPL application package"
|
readme ="CPL api package"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" }
|
{ name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" }
|
||||||
]
|
]
|
||||||
keywords = ["cpl", "application", "backend", "shared", "library"]
|
keywords = ["cpl", "api", "backend", "shared", "library"]
|
||||||
|
|
||||||
dynamic = ["dependencies", "optional-dependencies"]
|
dynamic = ["dependencies", "optional-dependencies"]
|
||||||
|
|
||||||
|
|||||||
@@ -4,3 +4,4 @@ cpl-core
|
|||||||
cpl-dependency
|
cpl-dependency
|
||||||
starlette==0.48.0
|
starlette==0.48.0
|
||||||
python-multipart==0.0.20
|
python-multipart==0.0.20
|
||||||
|
uvicorn==0.35.0
|
||||||
@@ -2,9 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Callable, Self
|
from typing import Callable, Self
|
||||||
|
|
||||||
from cpl.application.host import Host
|
from cpl.application.host import Host
|
||||||
from cpl.core.console.console import Console
|
|
||||||
from cpl.core.log import LogSettings
|
|
||||||
from cpl.core.log.log_level import LogLevel
|
from cpl.core.log.log_level import LogLevel
|
||||||
|
from cpl.core.log.log_settings import LogSettings
|
||||||
from cpl.core.log.logger_abc import LoggerABC
|
from cpl.core.log.logger_abc import LoggerABC
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
@@ -22,8 +21,15 @@ class ApplicationABC(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, services: ServiceProviderABC):
|
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
|
||||||
self._services = services
|
self._services = services
|
||||||
|
self._required_modules = (
|
||||||
|
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_modules(self) -> list[str]:
|
||||||
|
return self._required_modules
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
|
def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
|
||||||
@@ -80,7 +86,7 @@ class ApplicationABC(ABC):
|
|||||||
try:
|
try:
|
||||||
Host.run(self.main)
|
Host.run(self.main)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
Console.close()
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def main(self): ...
|
def main(self): ...
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
class ApplicationExtensionABC(ABC):
|
class ApplicationExtensionABC(ABC):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
|
|||||||
from cpl.application.abc.startup_abc import StartupABC
|
from cpl.application.abc.startup_abc import StartupABC
|
||||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||||
from cpl.application.host import Host
|
from cpl.application.host import Host
|
||||||
|
from cpl.core.errors import dependency_error
|
||||||
from cpl.dependency.service_collection import ServiceCollection
|
from cpl.dependency.service_collection import ServiceCollection
|
||||||
|
|
||||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||||
@@ -35,6 +36,18 @@ class ApplicationBuilder(Generic[TApp]):
|
|||||||
def service_provider(self):
|
def service_provider(self):
|
||||||
return self._services.build()
|
return self._services.build()
|
||||||
|
|
||||||
|
def validate_app_required_modules(self, app: ApplicationABC):
|
||||||
|
for module in app.required_modules:
|
||||||
|
if module in self._services.loaded_modules:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dependency_error(
|
||||||
|
module,
|
||||||
|
ImportError(
|
||||||
|
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
|
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
|
||||||
self._startup = startup
|
self._startup = startup
|
||||||
return self
|
return self
|
||||||
@@ -62,4 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
|
|||||||
for extension in self._app_extensions:
|
for extension in self._app_extensions:
|
||||||
Host.run(extension.run, self.service_provider)
|
Host.run(extension.run, self.service_provider)
|
||||||
|
|
||||||
return self._app(self.service_provider)
|
app = self._app(self.service_provider)
|
||||||
|
self.validate_app_required_modules(app)
|
||||||
|
return app
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from cpl.auth import permission as _permission
|
|||||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
||||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
||||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||||
from .auth_logger import AuthLogger
|
from .logger import AuthLogger
|
||||||
from .keycloak_settings import KeycloakSettings
|
from .keycloak_settings import KeycloakSettings
|
||||||
from .permission_seeder import PermissionSeeder
|
from .permission_seeder import PermissionSeeder
|
||||||
|
|
||||||
@@ -40,11 +40,10 @@ def _add_daos(collection: _ServiceCollection):
|
|||||||
def add_auth(collection: _ServiceCollection):
|
def add_auth(collection: _ServiceCollection):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from cpl.core.console import Console
|
|
||||||
from cpl.database.service.migration_service import MigrationService
|
|
||||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from cpl.database.service.migration_service import MigrationService
|
||||||
|
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||||
|
|
||||||
collection.add_singleton(_KeycloakClient)
|
collection.add_singleton(_KeycloakClient)
|
||||||
collection.add_singleton(_KeycloakAdmin)
|
collection.add_singleton(_KeycloakAdmin)
|
||||||
|
|
||||||
@@ -59,22 +58,25 @@ def add_auth(collection: _ServiceCollection):
|
|||||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||||
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
|
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
Console.error("cpl-auth is not installed", str(e))
|
from cpl.core.console import Console
|
||||||
|
|
||||||
|
Console.error("cpl-database is not installed", str(e))
|
||||||
|
|
||||||
|
|
||||||
def add_permission(collection: _ServiceCollection):
|
def add_permission(collection: _ServiceCollection):
|
||||||
from cpl.auth.permission_seeder import PermissionSeeder
|
from .permission_seeder import PermissionSeeder
|
||||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
from .permission.permissions_registry import PermissionsRegistry
|
||||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
from .permission.permissions import Permissions
|
||||||
from cpl.auth.permission.permissions import Permissions
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||||
|
|
||||||
collection.add_singleton(DataSeederABC, PermissionSeeder)
|
collection.add_singleton(DataSeederABC, PermissionSeeder)
|
||||||
PermissionsRegistry.with_enum(Permissions)
|
PermissionsRegistry.with_enum(Permissions)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
from cpl.core.console import Console
|
from cpl.core.console import Console
|
||||||
|
|
||||||
Console.error("cpl-auth is not installed", str(e))
|
Console.error("cpl-database is not installed", str(e))
|
||||||
|
|
||||||
|
|
||||||
_ServiceCollection.with_module(add_auth, __name__)
|
_ServiceCollection.with_module(add_auth, __name__)
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
from cpl.core.log import Logger
|
|
||||||
from cpl.core.typing import Source
|
|
||||||
|
|
||||||
|
|
||||||
class AuthLogger(Logger):
|
|
||||||
|
|
||||||
def __init__(self, source: Source):
|
|
||||||
Logger.__init__(self, source, "auth")
|
|
||||||
@@ -1,15 +1,13 @@
|
|||||||
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
|
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
|
||||||
|
|
||||||
from cpl.auth.auth_logger import AuthLogger
|
|
||||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||||
|
from cpl.auth.logger import AuthLogger
|
||||||
_logger = AuthLogger("keycloak")
|
|
||||||
|
|
||||||
|
|
||||||
class KeycloakAdmin(_KeycloakAdmin):
|
class KeycloakAdmin(_KeycloakAdmin):
|
||||||
|
|
||||||
def __init__(self, settings: KeycloakSettings):
|
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||||
_logger.info("Initializing Keycloak admin")
|
# logger.info("Initializing Keycloak admin")
|
||||||
_connection = KeycloakOpenIDConnection(
|
_connection = KeycloakOpenIDConnection(
|
||||||
server_url=settings.url,
|
server_url=settings.url,
|
||||||
client_id=settings.client_id,
|
client_id=settings.client_id,
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from keycloak import KeycloakOpenID, KeycloakAdmin, KeycloakOpenIDConnection
|
from typing import Optional
|
||||||
|
|
||||||
from cpl.auth.auth_logger import AuthLogger
|
from keycloak import KeycloakOpenID
|
||||||
|
|
||||||
|
from cpl.auth.logger import AuthLogger
|
||||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||||
|
|
||||||
_logger = AuthLogger("keycloak")
|
|
||||||
|
|
||||||
|
|
||||||
class KeycloakClient(KeycloakOpenID):
|
class KeycloakClient(KeycloakOpenID):
|
||||||
|
|
||||||
def __init__(self, settings: KeycloakSettings):
|
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||||
KeycloakOpenID.__init__(
|
KeycloakOpenID.__init__(
|
||||||
self,
|
self,
|
||||||
server_url=settings.url,
|
server_url=settings.url,
|
||||||
@@ -16,11 +16,8 @@ class KeycloakClient(KeycloakOpenID):
|
|||||||
realm_name=settings.realm,
|
realm_name=settings.realm,
|
||||||
client_secret_key=settings.client_secret,
|
client_secret_key=settings.client_secret,
|
||||||
)
|
)
|
||||||
_logger.info("Initializing Keycloak client")
|
logger.info("Initializing Keycloak client")
|
||||||
connection = KeycloakOpenIDConnection(
|
|
||||||
server_url=settings.url,
|
def get_user_id(self, token: str) -> Optional[str]:
|
||||||
client_id=settings.client_id,
|
info = self.introspect(token)
|
||||||
realm_name=settings.realm,
|
return info.get("sub", None)
|
||||||
client_secret_key=settings.client_secret,
|
|
||||||
)
|
|
||||||
self._admin = KeycloakAdmin(connection=connection)
|
|
||||||
|
|||||||
7
src/cpl-auth/cpl/auth/logger.py
Normal file
7
src/cpl-auth/cpl/auth/logger.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||||
|
|
||||||
|
|
||||||
|
class AuthLogger(WrappedLogger):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
WrappedLogger.__init__(self, "auth")
|
||||||
@@ -14,14 +14,13 @@ from cpl.auth.schema import (
|
|||||||
)
|
)
|
||||||
from cpl.core.utils.get_value import get_value
|
from cpl.core.utils.get_value import get_value
|
||||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class PermissionSeeder(DataSeederABC):
|
class PermissionSeeder(DataSeederABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
logger: DBLogger,
|
||||||
permission_dao: PermissionDao,
|
permission_dao: PermissionDao,
|
||||||
role_dao: RoleDao,
|
role_dao: RoleDao,
|
||||||
role_permission_dao: RolePermissionDao,
|
role_permission_dao: RolePermissionDao,
|
||||||
@@ -29,6 +28,7 @@ class PermissionSeeder(DataSeederABC):
|
|||||||
api_key_permission_dao: ApiKeyPermissionDao,
|
api_key_permission_dao: ApiKeyPermissionDao,
|
||||||
):
|
):
|
||||||
DataSeederABC.__init__(self)
|
DataSeederABC.__init__(self)
|
||||||
|
self._logger = logger
|
||||||
self._permission_dao = permission_dao
|
self._permission_dao = permission_dao
|
||||||
self._role_dao = role_dao
|
self._role_dao = role_dao
|
||||||
self._role_permission_dao = role_permission_dao
|
self._role_permission_dao = role_permission_dao
|
||||||
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
|
|||||||
possible_permissions = [permission for permission in PermissionsRegistry.get()]
|
possible_permissions = [permission for permission in PermissionsRegistry.get()]
|
||||||
|
|
||||||
if len(permissions) == len(possible_permissions):
|
if len(permissions) == len(possible_permissions):
|
||||||
_logger.info("Permissions already existing")
|
self._logger.info("Permissions already existing")
|
||||||
await self._update_missing_descriptions()
|
await self._update_missing_descriptions()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
|
|||||||
|
|
||||||
await self._permission_dao.delete_many(to_delete, hard_delete=True)
|
await self._permission_dao.delete_many(to_delete, hard_delete=True)
|
||||||
|
|
||||||
_logger.warning("Permissions incomplete")
|
self._logger.warning("Permissions incomplete")
|
||||||
permission_names = [permission.name for permission in permissions]
|
permission_names = [permission.name for permission in permissions]
|
||||||
await self._permission_dao.create_many(
|
await self._permission_dao.create_many(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -3,15 +3,12 @@ from typing import Optional
|
|||||||
from cpl.auth.schema._administration.api_key import ApiKey
|
from cpl.auth.schema._administration.api_key import ApiKey
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyDao(DbModelDaoABC[ApiKey]):
|
class ApiKeyDao(DbModelDaoABC[ApiKey]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
|
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys"))
|
||||||
|
|
||||||
self.attribute(ApiKey.identifier, str)
|
self.attribute(ApiKey.identifier, str)
|
||||||
self.attribute(ApiKey.key, str, "keystring")
|
self.attribute(ApiKey.key, str, "keystring")
|
||||||
|
|||||||
@@ -6,14 +6,12 @@ from async_property import async_property
|
|||||||
from keycloak import KeycloakGetError
|
from keycloak import KeycloakGetError
|
||||||
|
|
||||||
from cpl.auth.keycloak import KeycloakAdmin
|
from cpl.auth.keycloak import KeycloakAdmin
|
||||||
from cpl.auth.auth_logger import AuthLogger
|
|
||||||
from cpl.auth.permission.permissions import Permissions
|
from cpl.auth.permission.permissions import Permissions
|
||||||
from cpl.core.typing import SerialId
|
from cpl.core.typing import SerialId
|
||||||
from cpl.database.abc import DbModelABC
|
from cpl.database.abc import DbModelABC
|
||||||
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
_logger = AuthLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthUser(DbModelABC):
|
class AuthUser(DbModelABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -38,12 +36,13 @@ class AuthUser(DbModelABC):
|
|||||||
return "ANONYMOUS"
|
return "ANONYMOUS"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||||
return keycloak_admin.get_user(self._keycloak_id).get("username")
|
return keycloak.get_user(self._keycloak_id).get("username")
|
||||||
except KeycloakGetError as e:
|
except KeycloakGetError as e:
|
||||||
return "UNKNOWN"
|
return "UNKNOWN"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||||
|
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||||
return "UNKNOWN"
|
return "UNKNOWN"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -52,12 +51,13 @@ class AuthUser(DbModelABC):
|
|||||||
return "ANONYMOUS"
|
return "ANONYMOUS"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||||
return keycloak_admin.get_user(self._keycloak_id).get("email")
|
return keycloak.get_user(self._keycloak_id).get("email")
|
||||||
except KeycloakGetError as e:
|
except KeycloakGetError as e:
|
||||||
return "UNKNOWN"
|
return "UNKNOWN"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||||
|
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||||
return "UNKNOWN"
|
return "UNKNOWN"
|
||||||
|
|
||||||
@async_property
|
@async_property
|
||||||
|
|||||||
@@ -4,19 +4,16 @@ from cpl.auth.permission.permissions import Permissions
|
|||||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users"))
|
||||||
|
|
||||||
self.attribute(AuthUser.keycloak_id, str, aliases=["keycloakId"])
|
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
|
||||||
|
|
||||||
async def get_users():
|
async def get_users():
|
||||||
return [(x.id, x.username, x.email) for x in await self.get_all()]
|
return [(x.id, x.username, x.email) for x in await self.get_all()]
|
||||||
@@ -43,9 +40,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
|||||||
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
||||||
result = await self._db.select_map(
|
result = await self._db.select_map(
|
||||||
f"""
|
f"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*) as count
|
||||||
FROM permission.role_users ru
|
FROM {TableManager.get("role_users")} ru
|
||||||
JOIN permission.role_permissions rp ON ru.roleId = rp.roleId
|
JOIN {TableManager.get("role_permissions")} rp ON ru.roleId = rp.roleId
|
||||||
WHERE ru.userId = {user_id}
|
WHERE ru.userId = {user_id}
|
||||||
AND rp.permissionId = {p.id}
|
AND rp.permissionId = {p.id}
|
||||||
AND ru.deleted = FALSE
|
AND ru.deleted = FALSE
|
||||||
@@ -61,9 +58,9 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
|||||||
result = await self._db.select_map(
|
result = await self._db.select_map(
|
||||||
f"""
|
f"""
|
||||||
SELECT p.*
|
SELECT p.*
|
||||||
FROM permission.permissions p
|
FROM {TableManager.get("permissions")} p
|
||||||
JOIN permission.role_permissions rp ON p.id = rp.permissionId
|
JOIN {TableManager.get("role_permissions")} rp ON p.id = rp.permissionId
|
||||||
JOIN permission.role_users ru ON rp.roleId = ru.roleId
|
JOIN {TableManager.get("role_users")} ru ON rp.roleId = ru.roleId
|
||||||
WHERE ru.userId = {user_id}
|
WHERE ru.userId = {user_id}
|
||||||
AND rp.deleted = FALSE
|
AND rp.deleted = FALSE
|
||||||
AND ru.deleted = FALSE;
|
AND ru.deleted = FALSE;
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
|
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
|
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||||
|
|
||||||
self.attribute(ApiKeyPermission.api_key_id, int)
|
self.attribute(ApiKeyPermission.api_key_id, int)
|
||||||
self.attribute(ApiKeyPermission.permission_id, int)
|
self.attribute(ApiKeyPermission.permission_id, int)
|
||||||
|
|||||||
@@ -3,15 +3,12 @@ from typing import Optional
|
|||||||
from cpl.auth.schema._permission.permission import Permission
|
from cpl.auth.schema._permission.permission import Permission
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class PermissionDao(DbModelDaoABC[Permission]):
|
class PermissionDao(DbModelDaoABC[Permission]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
|
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions"))
|
||||||
|
|
||||||
self.attribute(Permission.name, str)
|
self.attribute(Permission.name, str)
|
||||||
self.attribute(Permission.description, Optional[str])
|
self.attribute(Permission.description, Optional[str])
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
from cpl.auth.schema._permission.role import Role
|
from cpl.auth.schema._permission.role import Role
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RoleDao(DbModelDaoABC[Role]):
|
class RoleDao(DbModelDaoABC[Role]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
|
DbModelDaoABC.__init__(self, Role, TableManager.get("roles"))
|
||||||
self.attribute(Role.name, str)
|
self.attribute(Role.name, str)
|
||||||
self.attribute(Role.description, str)
|
self.attribute(Role.description, str)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
from cpl.auth.schema._permission.role_permission import RolePermission
|
from cpl.auth.schema._permission.role_permission import RolePermission
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RolePermissionDao(DbModelDaoABC[RolePermission]):
|
class RolePermissionDao(DbModelDaoABC[RolePermission]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
|
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions"))
|
||||||
|
|
||||||
self.attribute(RolePermission.role_id, int)
|
self.attribute(RolePermission.role_id, int)
|
||||||
self.attribute(RolePermission.permission_id, int)
|
self.attribute(RolePermission.permission_id, int)
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
from cpl.auth.schema._permission.role_user import RoleUser
|
from cpl.auth.schema._permission.role_user import RoleUser
|
||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RoleUserDao(DbModelDaoABC[RoleUser]):
|
class RoleUserDao(DbModelDaoABC[RoleUser]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
|
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users"))
|
||||||
|
|
||||||
self.attribute(RoleUser.role_id, int)
|
self.attribute(RoleUser.role_id, int)
|
||||||
self.attribute(RoleUser.user_id, int)
|
self.attribute(RoleUser.user_id, int)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS administration_auth_users_history
|
CREATE TABLE IF NOT EXISTS administration_auth_users_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
keycloakId CHAR(36) NOT NULL,
|
keycloakId CHAR(36) NOT NULL,
|
||||||
-- for history
|
-- for history
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS administration_api_keys_history
|
CREATE TABLE IF NOT EXISTS administration_api_keys_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
identifier VARCHAR(255) NOT NULL,
|
identifier VARCHAR(255) NOT NULL,
|
||||||
keyString VARCHAR(255) NOT NULL,
|
keyString VARCHAR(255) NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_permissions_history
|
CREATE TABLE IF NOT EXISTS permission_permissions_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
name VARCHAR(255) NOT NULL,
|
name VARCHAR(255) NOT NULL,
|
||||||
description TEXT NULL,
|
description TEXT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
@@ -57,7 +57,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_roles_history
|
CREATE TABLE IF NOT EXISTS permission_roles_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
name VARCHAR(255) NOT NULL,
|
name VARCHAR(255) NOT NULL,
|
||||||
description TEXT NULL,
|
description TEXT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
@@ -103,7 +103,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
|
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
RoleId INT NOT NULL,
|
RoleId INT NOT NULL,
|
||||||
permissionId INT NOT NULL,
|
permissionId INT NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
@@ -149,7 +149,7 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
|
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
RoleId INT NOT NULL,
|
RoleId INT NOT NULL,
|
||||||
UserId INT NOT NULL,
|
UserId INT NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
|
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
id INT NOT NULL,
|
||||||
apiKeyId INT NOT NULL,
|
apiKeyId INT NOT NULL,
|
||||||
permissionId INT NOT NULL,
|
permissionId INT NOT NULL,
|
||||||
deleted BOOL NOT NULL,
|
deleted BOOL NOT NULL,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
cpl-core
|
cpl-core
|
||||||
cpl-dependency
|
cpl-dependency
|
||||||
cpl-database
|
cpl-database
|
||||||
python-keycloak-5.8.1
|
python-keycloak==5.8.1
|
||||||
0
src/cpl-core/cpl/core/abc/__init__.py
Normal file
0
src/cpl-core/cpl/core/abc/__init__.py
Normal file
23
src/cpl-core/cpl/core/abc/registry_abc.py
Normal file
23
src/cpl-core/cpl/core/abc/registry_abc.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from typing import Generic
|
||||||
|
|
||||||
|
from cpl.core.typing import T
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryABC(ABC, Generic[T]):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self):
|
||||||
|
self._items: dict[str, T] = {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extend(self, items: list[T]) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add(self, item: T) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, key: str) -> T | None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def all(self) -> list[T]: ...
|
||||||
@@ -130,7 +130,7 @@ class Configuration:
|
|||||||
key_name = key.__name__ if inspect.isclass(key) else key
|
key_name = key.__name__ if inspect.isclass(key) else key
|
||||||
|
|
||||||
result = cls._config.get(key_name, default)
|
result = cls._config.get(key_name, default)
|
||||||
if issubclass(key, ConfigurationModelABC) and result == default:
|
if isclass(key) and issubclass(key, ConfigurationModelABC) and result == default:
|
||||||
result = key()
|
result = key()
|
||||||
cls.set(key, result)
|
cls.set(key, result)
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class ConfigurationModelABC(ABC):
|
|||||||
value = cast(Environment.get(env_field, str), cast_type)
|
value = cast(Environment.get(env_field, str), cast_type)
|
||||||
|
|
||||||
if value is None and required:
|
if value is None and required:
|
||||||
raise ValueError(f"{field} is required")
|
raise ValueError(f"{type(self).__name__}.{field} is required")
|
||||||
elif value is None:
|
elif value is None:
|
||||||
self._options[field] = default
|
self._options[field] = default
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from cpl.auth.auth_logger import AuthLogger
|
|
||||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||||
|
|
||||||
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
||||||
|
|
||||||
_logger = AuthLogger(__name__)
|
|
||||||
|
|
||||||
|
def set_user(user: Optional[AuthUser]):
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
from cpl.core.log.logger_abc import LoggerABC
|
||||||
|
|
||||||
def set_user(user_id: Optional[AuthUser]):
|
logger = ServiceProviderABC.get_global_service(LoggerABC)
|
||||||
_logger.trace("Setting user context", user_id)
|
logger.trace("Setting user context", user.id)
|
||||||
_user_context.set(user_id)
|
_user_context.set(user)
|
||||||
|
|
||||||
|
|
||||||
def get_user() -> Optional[AuthUser]:
|
def get_user() -> Optional[AuthUser]:
|
||||||
|
|||||||
15
src/cpl-core/cpl/core/errors.py
Normal file
15
src/cpl-core/cpl/core/errors.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import traceback
|
||||||
|
|
||||||
|
from cpl.core.console import Console
|
||||||
|
|
||||||
|
|
||||||
|
def dependency_error(package_name: str, e: ImportError) -> None:
|
||||||
|
Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
if not tb.startswith("NoneType: None"):
|
||||||
|
Console.write_line("->", tb)
|
||||||
|
|
||||||
|
elif e is not None:
|
||||||
|
Console.write_line("->", str(e))
|
||||||
|
|
||||||
|
exit(1)
|
||||||
@@ -2,3 +2,4 @@ from .logger import Logger
|
|||||||
from .logger_abc import LoggerABC
|
from .logger_abc import LoggerABC
|
||||||
from .log_level import LogLevel
|
from .log_level import LogLevel
|
||||||
from .log_settings import LogSettings
|
from .log_settings import LogSettings
|
||||||
|
from .structured_logger import StructuredLogger
|
||||||
|
|||||||
111
src/cpl-core/cpl/core/log/structured_logger.py
Normal file
111
src/cpl-core/cpl/core/log/structured_logger.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import asyncio
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from cpl.core.log.log_level import LogLevel
|
||||||
|
from cpl.core.log.logger import Logger
|
||||||
|
from cpl.core.typing import Source, Messages
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredLogger(Logger):
|
||||||
|
|
||||||
|
def __init__(self, source: Source, file_prefix: str = None):
|
||||||
|
Logger.__init__(self, source, file_prefix)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log_file(self):
|
||||||
|
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
|
||||||
|
|
||||||
|
def _log(self, level: LogLevel, *messages: Messages):
|
||||||
|
try:
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||||
|
formatted_message = self._format_message(level.value, timestamp, *messages)
|
||||||
|
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
|
||||||
|
|
||||||
|
self._write_log_to_file(level, structured_message)
|
||||||
|
self._write_to_console(level, formatted_message)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error while logging: {e} -> {traceback.format_exc()}")
|
||||||
|
|
||||||
|
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
|
||||||
|
structured_message = {
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"level": level.upper(),
|
||||||
|
"source": self._source,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
self._enrich_message_with_request(structured_message)
|
||||||
|
self._enrich_message_with_user(structured_message)
|
||||||
|
|
||||||
|
return json.dumps(structured_message, ensure_ascii=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
|
||||||
|
scope = dict(request.scope)
|
||||||
|
|
||||||
|
def convert(value):
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return value.decode("utf-8")
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
return [convert(v) for v in value]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {str(k): convert(v) for k, v in value.items()}
|
||||||
|
if not isinstance(value, (str, int, float, bool, type(None))):
|
||||||
|
return str(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
|
||||||
|
|
||||||
|
if not include_headers and "headers" in serializable_scope:
|
||||||
|
serializable_scope["headers"] = "<omitted>"
|
||||||
|
|
||||||
|
return serializable_scope
|
||||||
|
|
||||||
|
def _enrich_message_with_request(self, message: dict):
|
||||||
|
if importlib.util.find_spec("cpl.api") is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
request = get_request()
|
||||||
|
|
||||||
|
if request is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
message["request"] = {
|
||||||
|
"url": str(request.url),
|
||||||
|
"method": request.method,
|
||||||
|
"scope": self._scope_to_json(request),
|
||||||
|
}
|
||||||
|
if isinstance(request, Request) and request.scope == "http":
|
||||||
|
request: Request = request # fix typing for IDEs
|
||||||
|
|
||||||
|
message["request"]["data"] = asyncio.create_task(request.body())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _enrich_message_with_user(message: dict):
|
||||||
|
if importlib.util.find_spec("cpl-auth") is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from cpl.core.ctx import get_user
|
||||||
|
|
||||||
|
user = get_user()
|
||||||
|
if user is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
|
||||||
|
|
||||||
|
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||||
|
kc_user = keycloak.get_user(user.keycloak_id)
|
||||||
|
message["user"] = {
|
||||||
|
"id": str(user.id),
|
||||||
|
"username": kc_user.get("username"),
|
||||||
|
"email": kc_user.get("email"),
|
||||||
|
}
|
||||||
97
src/cpl-core/cpl/core/log/wrapped_logger.py
Normal file
97
src/cpl-core/cpl/core/log/wrapped_logger.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
from cpl.core.log import LoggerABC, LogLevel
|
||||||
|
from cpl.core.typing import Messages, Source
|
||||||
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedLogger(LoggerABC):
|
||||||
|
|
||||||
|
def __init__(self, file_prefix: str):
|
||||||
|
LoggerABC.__init__(self)
|
||||||
|
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
|
||||||
|
|
||||||
|
t_logger = ServiceProviderABC.get_global_service(LoggerABC)
|
||||||
|
self._t_logger = type(t_logger) if t_logger is not None else None
|
||||||
|
self._source = None
|
||||||
|
self._file_prefix = file_prefix
|
||||||
|
|
||||||
|
self._set_logger()
|
||||||
|
|
||||||
|
def _set_logger(self):
|
||||||
|
if self._t_logger is None:
|
||||||
|
raise Exception("No LoggerABC service registered in ServiceProviderABC")
|
||||||
|
|
||||||
|
self._logger = self._t_logger(self._source, self._file_prefix)
|
||||||
|
|
||||||
|
def set_level(self, level: LogLevel):
|
||||||
|
self._logger.set_level(level)
|
||||||
|
|
||||||
|
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||||
|
return self._logger._format_message(level, timestamp, *messages)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_source() -> str | None:
|
||||||
|
stack = inspect.stack()
|
||||||
|
if len(stack) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from cpl.dependency import ServiceCollection
|
||||||
|
|
||||||
|
ignore_classes = [
|
||||||
|
ServiceProviderABC,
|
||||||
|
ServiceProviderABC.__subclasses__(),
|
||||||
|
ServiceCollection,
|
||||||
|
WrappedLogger,
|
||||||
|
WrappedLogger.__subclasses__(),
|
||||||
|
]
|
||||||
|
|
||||||
|
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
|
||||||
|
|
||||||
|
for i, frame_info in enumerate(stack[1:]):
|
||||||
|
module = inspect.getmodule(frame_info.frame)
|
||||||
|
if module is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if module.__name__ in ignore_classes or module in ignore_classes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if module in ignore_modules or module.__name__ in ignore_modules:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if module.__name__ != __name__:
|
||||||
|
return module.__name__
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _set_source(self):
|
||||||
|
self._source = self._get_source()
|
||||||
|
self._set_logger()
|
||||||
|
|
||||||
|
def header(self, string: str):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.header(string)
|
||||||
|
|
||||||
|
def trace(self, *messages: Messages):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.trace(*messages)
|
||||||
|
|
||||||
|
def debug(self, *messages: Messages):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.debug(*messages)
|
||||||
|
|
||||||
|
def info(self, *messages: Messages):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.info(*messages)
|
||||||
|
|
||||||
|
def warning(self, *messages: Messages):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.warning(*messages)
|
||||||
|
|
||||||
|
def error(self, messages: str, e: Exception = None):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.error(messages, e)
|
||||||
|
|
||||||
|
def fatal(self, messages: str, e: Exception = None):
|
||||||
|
self._set_source()
|
||||||
|
self._logger.fatal(messages, e)
|
||||||
@@ -2,5 +2,4 @@ art==6.5
|
|||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
tabulate==0.9.0
|
tabulate==0.9.0
|
||||||
termcolor==3.1.0
|
termcolor==3.1.0
|
||||||
mysql-connector-python==9.4.0
|
|
||||||
pynput==1.8.1
|
pynput==1.8.1
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
||||||
@@ -7,13 +8,19 @@ from . import postgres as _postgres
|
|||||||
from .table_manager import TableManager
|
from .table_manager import TableManager
|
||||||
|
|
||||||
|
|
||||||
def _with_migrations(self: _ApplicationABC, *paths: list[str]) -> _ApplicationABC:
|
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
|
||||||
from cpl.application.host import Host
|
from cpl.application.host import Host
|
||||||
|
|
||||||
from cpl.database.service.migration_service import MigrationService
|
from cpl.database.service.migration_service import MigrationService
|
||||||
|
|
||||||
migration_service = self._services.get_service(MigrationService)
|
migration_service = self._services.get_service(MigrationService)
|
||||||
migration_service.with_directory("./scripts")
|
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
|
||||||
|
|
||||||
|
if isinstance(paths, str):
|
||||||
|
paths = [paths]
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
migration_service.with_directory(path)
|
||||||
|
|
||||||
Host.run(migration_service.migrate)
|
Host.run(migration_service.migrate)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from cpl.core.utils.get_value import get_value
|
|||||||
from cpl.core.utils.string import String
|
from cpl.core.utils.string import String
|
||||||
from cpl.database.abc.db_context_abc import DBContextABC
|
from cpl.database.abc.db_context_abc import DBContextABC
|
||||||
from cpl.database.const import DATETIME_FORMAT
|
from cpl.database.const import DATETIME_FORMAT
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||||
@@ -18,16 +18,12 @@ from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSor
|
|||||||
class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
|
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||||
|
|
||||||
self._db = ServiceProviderABC.get_global_service(DBContextABC)
|
self._db = ServiceProviderABC.get_global_service(DBContextABC)
|
||||||
|
|
||||||
self._logger = DBLogger(source)
|
self._logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||||
self._model_type = model_type
|
|
||||||
self._table_name = table_name
|
|
||||||
|
|
||||||
self._logger = DBLogger(source)
|
|
||||||
self._model_type = model_type
|
self._model_type = model_type
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
|
|
||||||
@@ -156,13 +152,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
|||||||
:param dict result: Result from the database
|
:param dict result: Result from the database
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
value_map: dict[str, T] = {}
|
value_map: dict[str, Any] = {}
|
||||||
|
db_names = self.__db_names.items()
|
||||||
|
|
||||||
for db_name, value in result.items():
|
for db_name, value in result.items():
|
||||||
# Find the attribute name corresponding to the db_name
|
# Find the attribute name corresponding to the db_name
|
||||||
attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
|
attr_name = next((k for k, v in db_names if v == db_name), None)
|
||||||
if attr_name:
|
if not attr_name:
|
||||||
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
|
continue
|
||||||
|
|
||||||
|
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
|
||||||
|
|
||||||
return self._model_type(**value_map)
|
return self._model_type(**value_map)
|
||||||
|
|
||||||
@@ -484,7 +483,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
|||||||
builder.with_temp_table(self._external_fields[temp])
|
builder.with_temp_table(self._external_fields[temp])
|
||||||
|
|
||||||
if for_count:
|
if for_count:
|
||||||
builder.with_attribute("COUNT(*)", ignore_table_name=True)
|
builder.with_attribute("COUNT(*) as count", ignore_table_name=True)
|
||||||
else:
|
else:
|
||||||
builder.with_attribute("*")
|
builder.with_attribute("*")
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
|
|||||||
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
|
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
|
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||||
DataAccessObjectABC.__init__(self, source, model_type, table_name)
|
DataAccessObjectABC.__init__(self, model_type, table_name)
|
||||||
|
|
||||||
self.attribute(DbModelABC.id, int, ignore=True)
|
self.attribute(DbModelABC.id, int, ignore=True)
|
||||||
self.attribute(DbModelABC.deleted, bool)
|
self.attribute(DbModelABC.deleted, bool)
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
from cpl.core.log import Logger
|
|
||||||
from cpl.core.typing import Source
|
|
||||||
|
|
||||||
|
|
||||||
class DBLogger(Logger):
|
|
||||||
|
|
||||||
def __init__(self, source: Source):
|
|
||||||
Logger.__init__(self, source, "db")
|
|
||||||
7
src/cpl-database/cpl/database/logger.py
Normal file
7
src/cpl-database/cpl/database/logger.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||||
|
|
||||||
|
|
||||||
|
class DBLogger(WrappedLogger):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
WrappedLogger.__init__(self, "db")
|
||||||
@@ -4,18 +4,17 @@ from typing import Any, List, Dict, Tuple, Union
|
|||||||
from mysql.connector import Error as MySQLError, PoolError
|
from mysql.connector import Error as MySQLError, PoolError
|
||||||
|
|
||||||
from cpl.core.configuration import Configuration
|
from cpl.core.configuration import Configuration
|
||||||
from cpl.core.environment import Environment
|
|
||||||
from cpl.database.abc.db_context_abc import DBContextABC
|
from cpl.database.abc.db_context_abc import DBContextABC
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.model.database_settings import DatabaseSettings
|
from cpl.database.model.database_settings import DatabaseSettings
|
||||||
from cpl.database.mysql.mysql_pool import MySQLPool
|
from cpl.database.mysql.mysql_pool import MySQLPool
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DBContext(DBContextABC):
|
class DBContext(DBContextABC):
|
||||||
def __init__(self):
|
def __init__(self, logger: DBLogger):
|
||||||
DBContextABC.__init__(self)
|
DBContextABC.__init__(self)
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
self._pool: MySQLPool = None
|
self._pool: MySQLPool = None
|
||||||
self._fails = 0
|
self._fails = 0
|
||||||
|
|
||||||
@@ -23,62 +22,62 @@ class DBContext(DBContextABC):
|
|||||||
|
|
||||||
def connect(self, database_settings: DatabaseSettings):
|
def connect(self, database_settings: DatabaseSettings):
|
||||||
try:
|
try:
|
||||||
_logger.debug("Connecting to database")
|
self._logger.debug("Connecting to database")
|
||||||
self._pool = MySQLPool(
|
self._pool = MySQLPool(
|
||||||
database_settings,
|
database_settings,
|
||||||
)
|
)
|
||||||
_logger.info("Connected to database")
|
self._logger.info("Connected to database")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.fatal("Connecting to database failed", e)
|
self._logger.fatal("Connecting to database failed", e)
|
||||||
|
|
||||||
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
|
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
|
||||||
_logger.trace(f"execute {statement} with args: {args}")
|
self._logger.trace(f"execute {statement} with args: {args}")
|
||||||
return await self._pool.execute(statement, args, multi)
|
return await self._pool.execute(statement, args, multi)
|
||||||
|
|
||||||
async def select_map(self, statement: str, args=None) -> List[Dict]:
|
async def select_map(self, statement: str, args=None) -> List[Dict]:
|
||||||
_logger.trace(f"select {statement} with args: {args}")
|
self._logger.trace(f"select {statement} with args: {args}")
|
||||||
try:
|
try:
|
||||||
return await self._pool.select_map(statement, args)
|
return await self._pool.select_map(statement, args)
|
||||||
except (MySQLError, PoolError) as e:
|
except (MySQLError, PoolError) as e:
|
||||||
if self._fails >= 3:
|
if self._fails >= 3:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
uid = uuid.uuid4()
|
uid = uuid.uuid4()
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||||
)
|
)
|
||||||
|
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
self._fails += 1
|
self._fails += 1
|
||||||
try:
|
try:
|
||||||
_logger.debug("Retry select")
|
self._logger.debug("Retry select")
|
||||||
return await self.select_map(statement, args)
|
return await self.select_map(statement, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
|
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
|
||||||
_logger.trace(f"select {statement} with args: {args}")
|
self._logger.trace(f"select {statement} with args: {args}")
|
||||||
try:
|
try:
|
||||||
return await self._pool.select(statement, args)
|
return await self._pool.select(statement, args)
|
||||||
except (MySQLError, PoolError) as e:
|
except (MySQLError, PoolError) as e:
|
||||||
if self._fails >= 3:
|
if self._fails >= 3:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
uid = uuid.uuid4()
|
uid = uuid.uuid4()
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||||
)
|
)
|
||||||
|
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
self._fails += 1
|
self._fails += 1
|
||||||
try:
|
try:
|
||||||
_logger.debug("Retry select")
|
self._logger.debug("Retry select")
|
||||||
return await self.select(statement, args)
|
return await self.select(statement, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -1,105 +1,92 @@
|
|||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
import aiomysql
|
from mysql.connector.aio import MySQLConnectionPool
|
||||||
|
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.model import DatabaseSettings
|
from cpl.database.model import DatabaseSettings
|
||||||
|
from cpl.dependency import ServiceProviderABC
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MySQLPool:
|
class MySQLPool:
|
||||||
"""
|
|
||||||
Create a pool when connecting to MySQL, which will decrease the time spent in
|
|
||||||
requesting connection, creating connection, and closing connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, database_settings: DatabaseSettings):
|
def __init__(self, database_settings: DatabaseSettings):
|
||||||
self._db_settings = database_settings
|
self._dbconfig = {
|
||||||
self.pool: Optional[aiomysql.Pool] = None
|
"host": database_settings.host,
|
||||||
|
"port": database_settings.port,
|
||||||
|
"user": database_settings.user,
|
||||||
|
"password": database_settings.password,
|
||||||
|
"database": database_settings.database,
|
||||||
|
"ssl_disabled": True,
|
||||||
|
}
|
||||||
|
self._pool: Optional[MySQLConnectionPool] = None
|
||||||
|
|
||||||
async def _get_pool(self):
|
async def _get_pool(self):
|
||||||
if self.pool is None or self.pool._closed:
|
if self._pool is None:
|
||||||
|
self._pool = MySQLConnectionPool(
|
||||||
|
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
||||||
|
)
|
||||||
|
await self._pool.initialize_pool()
|
||||||
|
|
||||||
|
con = await self._pool.get_connection()
|
||||||
try:
|
try:
|
||||||
self.pool = await aiomysql.create_pool(
|
async with await con.cursor() as cursor:
|
||||||
host=self._db_settings.host,
|
await cursor.execute("SELECT 1")
|
||||||
port=self._db_settings.port,
|
await cursor.fetchall()
|
||||||
user=self._db_settings.user,
|
|
||||||
password=self._db_settings.password,
|
|
||||||
db=self._db_settings.database,
|
|
||||||
minsize=1,
|
|
||||||
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.fatal("Failed to connect to the database", e)
|
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||||
raise
|
logger.fatal(f"Error connecting to the database: {e}")
|
||||||
return self.pool
|
finally:
|
||||||
|
await con.close()
|
||||||
|
|
||||||
|
return self._pool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||||
|
result = []
|
||||||
if multi:
|
if multi:
|
||||||
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
||||||
for q in queries:
|
for q in queries:
|
||||||
if q.strip() == "":
|
if q.strip() == "":
|
||||||
continue
|
continue
|
||||||
await cursor.execute(q, args)
|
await cursor.execute(q, args)
|
||||||
|
if cursor.description is not None:
|
||||||
|
result = await cursor.fetchall()
|
||||||
else:
|
else:
|
||||||
await cursor.execute(query, args)
|
await cursor.execute(query, args)
|
||||||
|
if cursor.description is not None:
|
||||||
|
result = await cursor.fetchall()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
||||||
"""
|
|
||||||
Execute a SQL statement, it could be with args and without args. The usage is
|
|
||||||
similar to the execute() function in aiomysql.
|
|
||||||
:param query: SQL clause
|
|
||||||
:param args: args needed by the SQL clause
|
|
||||||
:param multi: if the query is a multi-statement
|
|
||||||
:return: return result
|
|
||||||
"""
|
|
||||||
pool = await self._get_pool()
|
pool = await self._get_pool()
|
||||||
async with pool.acquire() as con:
|
con = await pool.get_connection()
|
||||||
async with con.cursor() as cursor:
|
try:
|
||||||
await self._exec_sql(cursor, query, args, multi)
|
async with await con.cursor() as cursor:
|
||||||
|
result = await self._exec_sql(cursor, query, args, multi)
|
||||||
await con.commit()
|
await con.commit()
|
||||||
|
return result
|
||||||
if cursor.description is not None: # Query returns rows
|
finally:
|
||||||
res = await cursor.fetchall()
|
await con.close()
|
||||||
if res is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [list(row) for row in res]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
||||||
"""
|
|
||||||
Execute a SQL statement, it could be with args and without args. The usage is
|
|
||||||
similar to the execute() function in aiomysql.
|
|
||||||
:param query: SQL clause
|
|
||||||
:param args: args needed by the SQL clause
|
|
||||||
:param multi: if the query is a multi-statement
|
|
||||||
:return: return result
|
|
||||||
"""
|
|
||||||
pool = await self._get_pool()
|
pool = await self._get_pool()
|
||||||
async with pool.acquire() as con:
|
con = await pool.get_connection()
|
||||||
async with con.cursor() as cursor:
|
try:
|
||||||
await self._exec_sql(cursor, query, args, multi)
|
async with await con.cursor() as cursor:
|
||||||
res = await cursor.fetchall()
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
return list(res)
|
return list(res)
|
||||||
|
finally:
|
||||||
|
await con.close()
|
||||||
|
|
||||||
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
||||||
"""
|
|
||||||
Execute a SQL statement, it could be with args and without args. The usage is
|
|
||||||
similar to the execute() function in aiomysql.
|
|
||||||
:param query: SQL clause
|
|
||||||
:param args: args needed by the SQL clause
|
|
||||||
:param multi: if the query is a multi-statement
|
|
||||||
:return: return result
|
|
||||||
"""
|
|
||||||
pool = await self._get_pool()
|
pool = await self._get_pool()
|
||||||
async with pool.acquire() as con:
|
con = await pool.get_connection()
|
||||||
async with con.cursor(aiomysql.DictCursor) as cursor:
|
try:
|
||||||
await self._exec_sql(cursor, query, args, multi)
|
async with await con.cursor(dictionary=True) as cursor:
|
||||||
res = await cursor.fetchall()
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
return list(res)
|
return list(res)
|
||||||
|
finally:
|
||||||
|
await con.close()
|
||||||
|
|||||||
@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
|
|||||||
from cpl.core.configuration import Configuration
|
from cpl.core.configuration import Configuration
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
from cpl.database.abc.db_context_abc import DBContextABC
|
from cpl.database.abc.db_context_abc import DBContextABC
|
||||||
from cpl.database.database_settings import DatabaseSettings
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.model import DatabaseSettings
|
||||||
from cpl.database.postgres.postgres_pool import PostgresPool
|
from cpl.database.postgres.postgres_pool import PostgresPool
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DBContext(DBContextABC):
|
class DBContext(DBContextABC):
|
||||||
def __init__(self):
|
def __init__(self, logger: DBLogger):
|
||||||
DBContextABC.__init__(self)
|
DBContextABC.__init__(self)
|
||||||
|
|
||||||
|
self._logger = logger
|
||||||
self._pool: PostgresPool = None
|
self._pool: PostgresPool = None
|
||||||
self._fails = 0
|
self._fails = 0
|
||||||
|
|
||||||
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
|
|||||||
|
|
||||||
def connect(self, database_settings: DatabaseSettings):
|
def connect(self, database_settings: DatabaseSettings):
|
||||||
try:
|
try:
|
||||||
_logger.debug("Connecting to database")
|
self._logger.debug("Connecting to database")
|
||||||
self._pool = PostgresPool(
|
self._pool = PostgresPool(
|
||||||
database_settings,
|
database_settings,
|
||||||
Environment.get("DB_POOL_SIZE", int, 1),
|
Environment.get("DB_POOL_SIZE", int, 1),
|
||||||
)
|
)
|
||||||
_logger.info("Connected to database")
|
self._logger.info("Connected to database")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.fatal("Connecting to database failed", e)
|
self._logger.fatal("Connecting to database failed", e)
|
||||||
|
|
||||||
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
|
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
|
||||||
_logger.trace(f"execute {statement} with args: {args}")
|
self._logger.trace(f"execute {statement} with args: {args}")
|
||||||
return await self._pool.execute(statement, args, multi)
|
return await self._pool.execute(statement, args, multi)
|
||||||
|
|
||||||
async def select_map(self, statement: str, args=None) -> list[dict]:
|
async def select_map(self, statement: str, args=None) -> list[dict]:
|
||||||
_logger.trace(f"select {statement} with args: {args}")
|
self._logger.trace(f"select {statement} with args: {args}")
|
||||||
try:
|
try:
|
||||||
return await self._pool.select_map(statement, args)
|
return await self._pool.select_map(statement, args)
|
||||||
except (OperationalError, PoolTimeout) as e:
|
except (OperationalError, PoolTimeout) as e:
|
||||||
if self._fails >= 3:
|
if self._fails >= 3:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
uid = uuid.uuid4()
|
uid = uuid.uuid4()
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||||
)
|
)
|
||||||
|
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
self._fails += 1
|
self._fails += 1
|
||||||
try:
|
try:
|
||||||
_logger.debug("Retry select")
|
self._logger.debug("Retry select")
|
||||||
return await self.select_map(statement, args)
|
return await self.select_map(statement, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
|
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
|
||||||
_logger.trace(f"select {statement} with args: {args}")
|
self._logger.trace(f"select {statement} with args: {args}")
|
||||||
try:
|
try:
|
||||||
return await self._pool.select(statement, args)
|
return await self._pool.select(statement, args)
|
||||||
except (OperationalError, PoolTimeout) as e:
|
except (OperationalError, PoolTimeout) as e:
|
||||||
if self._fails >= 3:
|
if self._fails >= 3:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
uid = uuid.uuid4()
|
uid = uuid.uuid4()
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||||
)
|
)
|
||||||
|
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
self._fails += 1
|
self._fails += 1
|
||||||
try:
|
try:
|
||||||
_logger.debug("Retry select")
|
self._logger.debug("Retry select")
|
||||||
return await self.select(statement, args)
|
return await self.select(statement, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"Database error caused by `{statement}`", e)
|
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ from psycopg import sql
|
|||||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||||
|
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.model import DatabaseSettings
|
from cpl.database.model import DatabaseSettings
|
||||||
|
from cpl.dependency import ServiceProviderABC
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class PostgresPool:
|
class PostgresPool:
|
||||||
@@ -25,21 +24,24 @@ class PostgresPool:
|
|||||||
f"password={database_settings.password} "
|
f"password={database_settings.password} "
|
||||||
f"dbname={database_settings.database}"
|
f"dbname={database_settings.database}"
|
||||||
)
|
)
|
||||||
|
self._pool: Optional[AsyncConnectionPool] = None
|
||||||
self.pool: Optional[AsyncConnectionPool] = None
|
|
||||||
|
|
||||||
async def _get_pool(self):
|
async def _get_pool(self):
|
||||||
pool = AsyncConnectionPool(
|
if self._pool is None:
|
||||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
pool = AsyncConnectionPool(
|
||||||
)
|
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||||
await pool.open()
|
)
|
||||||
try:
|
await pool.open()
|
||||||
async with pool.connection() as con:
|
try:
|
||||||
await pool.check_connection(con)
|
async with pool.connection() as con:
|
||||||
except PoolTimeout as e:
|
await pool.check_connection(con)
|
||||||
await pool.close()
|
except PoolTimeout as e:
|
||||||
_logger.fatal(f"Failed to connect to the database", e)
|
await pool.close()
|
||||||
return pool
|
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||||
|
logger.fatal(f"Failed to connect to the database", e)
|
||||||
|
self._pool = pool
|
||||||
|
|
||||||
|
return self._pool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
from cpl.database import TableManager
|
from cpl.database import TableManager
|
||||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||||
from cpl.database.db_logger import DBLogger
|
|
||||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
|
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
|
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations"))
|
||||||
|
|
||||||
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")
|
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")
|
||||||
|
|||||||
@@ -2,18 +2,17 @@ import glob
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from cpl.database.abc import DBContextABC
|
from cpl.database.abc import DBContextABC
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.database.model import Migration
|
from cpl.database.model import Migration
|
||||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MigrationService:
|
class MigrationService:
|
||||||
|
|
||||||
def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
||||||
|
self._logger = logger
|
||||||
self._db = db
|
self._db = db
|
||||||
self._executedMigrationDao = executedMigrationDao
|
self._executedMigrationDao = executedMigrationDao
|
||||||
|
|
||||||
@@ -96,13 +95,13 @@ class MigrationService:
|
|||||||
if migration_from_db is not None:
|
if migration_from_db is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_logger.debug(f"Running upgrade migration: {migration.name}")
|
self._logger.debug(f"Running upgrade migration: {migration.name}")
|
||||||
|
|
||||||
await self._db.execute(migration.script, multi=True)
|
await self._db.execute(migration.script, multi=True)
|
||||||
|
|
||||||
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
|
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.fatal(
|
self._logger.fatal(
|
||||||
f"Migration failed: {migration.name}\n{active_statement}",
|
f"Migration failed: {migration.name}\n{active_statement}",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,18 +1,16 @@
|
|||||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||||
from cpl.database.db_logger import DBLogger
|
from cpl.database.logger import DBLogger
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency import ServiceProviderABC
|
||||||
|
|
||||||
|
|
||||||
_logger = DBLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SeederService:
|
class SeederService:
|
||||||
|
|
||||||
def __init__(self, provider: ServiceProviderABC):
|
def __init__(self, provider: ServiceProviderABC):
|
||||||
self._provider = provider
|
self._provider = provider
|
||||||
|
self._logger = provider.get_service(DBLogger)
|
||||||
|
|
||||||
async def seed(self):
|
async def seed(self):
|
||||||
seeders = self._provider.get_services(DataSeederABC)
|
seeders = self._provider.get_services(DataSeederABC)
|
||||||
_logger.debug(f"Found {len(seeders)} seeders")
|
self._logger.debug(f"Found {len(seeders)} seeders")
|
||||||
for seeder in seeders:
|
for seeder in seeders:
|
||||||
await seeder.seed()
|
await seeder.seed()
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class TableManager:
|
|||||||
},
|
},
|
||||||
"role_users": {
|
"role_users": {
|
||||||
ServerTypes.POSTGRES: "permission.role_users",
|
ServerTypes.POSTGRES: "permission.role_users",
|
||||||
ServerTypes.MYSQL: "permission_role_users",
|
ServerTypes.MYSQL: "permission_role_auth_users",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Union, Type, Callable
|
from typing import Union, Type, Callable, Self
|
||||||
|
|
||||||
from cpl.core.log.logger import Logger
|
|
||||||
from cpl.core.log.logger_abc import LoggerABC
|
from cpl.core.log.logger_abc import LoggerABC
|
||||||
from cpl.core.typing import T, Service
|
from cpl.core.typing import T, Service
|
||||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||||
@@ -15,12 +14,17 @@ class ServiceCollection:
|
|||||||
_modules: dict[str, Callable] = {}
|
_modules: dict[str, Callable] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def with_module(cls, func: Callable, name: str = None):
|
def with_module(cls, func: Callable, name: str = None) -> type[Self]:
|
||||||
cls._modules[func.__name__ if name is None else name] = func
|
cls._modules[func.__name__ if name is None else name] = func
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._service_descriptors: list[ServiceDescriptor] = []
|
self._service_descriptors: list[ServiceDescriptor] = []
|
||||||
|
self._loaded_modules: set[str] = set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loaded_modules(self) -> set[str]:
|
||||||
|
return self._loaded_modules
|
||||||
|
|
||||||
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
|
def _add_descriptor(self, service: Union[type, object], lifetime: ServiceLifetimeEnum, base_type: Callable = None):
|
||||||
found = False
|
found = False
|
||||||
@@ -45,15 +49,15 @@ class ServiceCollection:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def add_singleton(self, service_type: T, service: Service = None):
|
def add_singleton(self, service_type: T, service: Service = None) -> Self:
|
||||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service)
|
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.singleton, service)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def add_scoped(self, service_type: T, service: Service = None):
|
def add_scoped(self, service_type: T, service: Service = None) -> Self:
|
||||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service)
|
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.scoped, service)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def add_transient(self, service_type: T, service: Service = None):
|
def add_transient(self, service_type: T, service: Service = None) -> Self:
|
||||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -62,7 +66,7 @@ class ServiceCollection:
|
|||||||
ServiceProviderABC.set_global_provider(sp)
|
ServiceProviderABC.set_global_provider(sp)
|
||||||
return sp
|
return sp
|
||||||
|
|
||||||
def add_module(self, module: str | object):
|
def add_module(self, module: str | object) -> Self:
|
||||||
if not isinstance(module, str):
|
if not isinstance(module, str):
|
||||||
module = module.__name__
|
module = module.__name__
|
||||||
|
|
||||||
@@ -70,7 +74,25 @@ class ServiceCollection:
|
|||||||
raise ValueError(f"Module {module} not found")
|
raise ValueError(f"Module {module} not found")
|
||||||
|
|
||||||
self._modules[module](self)
|
self._modules[module](self)
|
||||||
|
if module not in self._loaded_modules:
|
||||||
def add_logging(self):
|
self._loaded_modules.add(module)
|
||||||
self.add_transient(LoggerABC, Logger)
|
return self
|
||||||
|
|
||||||
|
def add_logging(self) -> Self:
|
||||||
|
from cpl.core.log.logger import Logger
|
||||||
|
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||||
|
|
||||||
|
self.add_transient(LoggerABC, Logger)
|
||||||
|
for wrapper in WrappedLogger.__subclasses__():
|
||||||
|
self.add_transient(wrapper)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_structured_logging(self) -> Self:
|
||||||
|
from cpl.core.log.structured_logger import StructuredLogger
|
||||||
|
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||||
|
|
||||||
|
self.add_transient(LoggerABC, StructuredLogger)
|
||||||
|
|
||||||
|
for wrapper in WrappedLogger.__subclasses__():
|
||||||
|
self.add_transient(wrapper)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class ServiceProvider(ServiceProviderABC):
|
|||||||
|
|
||||||
return implementations
|
return implementations
|
||||||
|
|
||||||
def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]:
|
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]:
|
||||||
params = []
|
params = []
|
||||||
for param in sig.parameters.items():
|
for param in sig.parameters.items():
|
||||||
parameter = param[1]
|
parameter = param[1]
|
||||||
|
|||||||
@@ -24,19 +24,19 @@ class ServiceProviderABC(ABC):
|
|||||||
return cls._provider
|
return cls._provider
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_global_service(cls, instance_type: T, *args, **kwargs) -> Optional[R]:
|
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||||
if cls._provider is None:
|
if cls._provider is None:
|
||||||
return None
|
return None
|
||||||
return cls._provider.get_service(instance_type, *args, **kwargs)
|
return cls._provider.get_service(instance_type, *args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_global_services(cls, instance_type: T, *args, **kwargs) -> list[Optional[R]]:
|
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||||
if cls._provider is None:
|
if cls._provider is None:
|
||||||
return []
|
return []
|
||||||
return cls._provider.get_services(instance_type, *args, **kwargs)
|
return cls._provider.get_services(instance_type, *args, **kwargs)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]: ...
|
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _build_service(self, service_type: type, *args, **kwargs) -> object:
|
def _build_service(self, service_type: type, *args, **kwargs) -> object:
|
||||||
@@ -114,14 +114,24 @@ class ServiceProviderABC(ABC):
|
|||||||
if f is None:
|
if f is None:
|
||||||
return functools.partial(cls.inject)
|
return functools.partial(cls.inject)
|
||||||
|
|
||||||
|
if iscoroutinefunction(f):
|
||||||
|
|
||||||
|
@functools.wraps(f)
|
||||||
|
async def async_inner(*args, **kwargs):
|
||||||
|
if cls._provider is None:
|
||||||
|
raise Exception(f"{cls.__name__} not build!")
|
||||||
|
|
||||||
|
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||||
|
return await f(*args, *injection, **kwargs)
|
||||||
|
|
||||||
|
return async_inner
|
||||||
|
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
async def inner(*args, **kwargs):
|
def inner(*args, **kwargs):
|
||||||
if cls._provider is None:
|
if cls._provider is None:
|
||||||
raise Exception(f"{cls.__name__} not build!")
|
raise Exception(f"{cls.__name__} not build!")
|
||||||
|
|
||||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||||
if iscoroutinefunction(f):
|
|
||||||
return await f(*args, *injection, **kwargs)
|
|
||||||
return f(*args, *injection, **kwargs)
|
return f(*args, *injection, **kwargs)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from .abc.email_client_abc import EMailClientABC
|
|||||||
from .email_client import EMailClient
|
from .email_client import EMailClient
|
||||||
from .email_client_settings import EMailClientSettings
|
from .email_client_settings import EMailClientSettings
|
||||||
from .email_model import EMail
|
from .email_model import EMail
|
||||||
from .mail_logger import MailLogger
|
from .logger import MailLogger
|
||||||
|
|
||||||
|
|
||||||
def add_mail(collection: _ServiceCollection):
|
def add_mail(collection: _ServiceCollection):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
from cpl.mail.abc.email_client_abc import EMailClientABC
|
from cpl.mail.abc.email_client_abc import EMailClientABC
|
||||||
from cpl.mail.email_client_settings import EMailClientSettings
|
from cpl.mail.email_client_settings import EMailClientSettings
|
||||||
from cpl.mail.email_model import EMail
|
from cpl.mail.email_model import EMail
|
||||||
from cpl.mail.mail_logger import MailLogger
|
from cpl.mail.logger import MailLogger
|
||||||
|
|
||||||
|
|
||||||
class EMailClient(EMailClientABC):
|
class EMailClient(EMailClientABC):
|
||||||
|
|||||||
7
src/cpl-mail/cpl/mail/logger.py
Normal file
7
src/cpl-mail/cpl/mail/logger.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||||
|
|
||||||
|
|
||||||
|
class MailLogger(WrappedLogger):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
WrappedLogger.__init__(self, "mail")
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
from cpl.core.log.logger import Logger
|
|
||||||
from cpl.core.typing import Source
|
|
||||||
|
|
||||||
|
|
||||||
class MailLogger(Logger):
|
|
||||||
|
|
||||||
def __init__(self, source: Source):
|
|
||||||
Logger.__init__(self, source, "mail")
|
|
||||||
8
tests/custom/api/src/appsettings.development.json
Normal file
8
tests/custom/api/src/appsettings.development.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"Logging": {
|
||||||
|
"Path": "logs/",
|
||||||
|
"Filename": "log_$start_time.log",
|
||||||
|
"ConsoleLevel": "TRACE",
|
||||||
|
"Level": "TRACE"
|
||||||
|
}
|
||||||
|
}
|
||||||
26
tests/custom/api/src/appsettings.edrafts-pc.json
Normal file
26
tests/custom/api/src/appsettings.edrafts-pc.json
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"TimeFormat": {
|
||||||
|
"DateFormat": "%Y-%m-%d",
|
||||||
|
"TimeFormat": "%H:%M:%S",
|
||||||
|
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
|
||||||
|
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
|
||||||
|
},
|
||||||
|
|
||||||
|
"Log": {
|
||||||
|
"Path": "logs/",
|
||||||
|
"Filename": "log_$start_time.log",
|
||||||
|
"ConsoleLevel": "TRACE",
|
||||||
|
"Level": "TRACE"
|
||||||
|
},
|
||||||
|
|
||||||
|
"Database": {
|
||||||
|
"Host": "localhost",
|
||||||
|
"User": "cpl",
|
||||||
|
"Port": 3306,
|
||||||
|
"Password": "cpl",
|
||||||
|
"Database": "cpl",
|
||||||
|
"Charset": "utf8mb4",
|
||||||
|
"UseUnicode": "true",
|
||||||
|
"Buffered": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
15
tests/custom/api/src/appsettings.json
Normal file
15
tests/custom/api/src/appsettings.json
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"TimeFormat": {
|
||||||
|
"DateFormat": "%Y-%m-%d",
|
||||||
|
"TimeFormat": "%H:%M:%S",
|
||||||
|
"DateTimeFormat": "%Y-%m-%d %H:%M:%S.%f",
|
||||||
|
"DateTimeLogFormat": "%Y-%m-%d_%H-%M-%S"
|
||||||
|
},
|
||||||
|
|
||||||
|
"Log": {
|
||||||
|
"Path": "logs/",
|
||||||
|
"Filename": "log_$start_time.log",
|
||||||
|
"ConsoleLevel": "ERROR",
|
||||||
|
"Level": "WARNING"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,20 +1,35 @@
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from cpl.api.web_app import WebApp
|
from cpl import api
|
||||||
|
from cpl.api.application.web_app import WebApp
|
||||||
from cpl.application import ApplicationBuilder
|
from cpl.application import ApplicationBuilder
|
||||||
from custom.api.src.service import PingService
|
from cpl.auth.permission.permissions import Permissions
|
||||||
|
from cpl.core.configuration import Configuration
|
||||||
|
from cpl.core.environment import Environment
|
||||||
|
from service import PingService
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
builder = ApplicationBuilder[WebApp](WebApp)
|
builder = ApplicationBuilder[WebApp](WebApp)
|
||||||
|
|
||||||
builder.services.add_logging()
|
Configuration.add_json_file(f"appsettings.json")
|
||||||
|
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
||||||
|
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
||||||
|
|
||||||
|
# builder.services.add_logging()
|
||||||
|
builder.services.add_structured_logging()
|
||||||
builder.services.add_transient(PingService)
|
builder.services.add_transient(PingService)
|
||||||
|
builder.services.add_module(api)
|
||||||
|
|
||||||
app = builder.build()
|
app = builder.build()
|
||||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
|
||||||
app.with_routes_directory("routes")
|
|
||||||
app.with_logging()
|
app.with_logging()
|
||||||
|
app.with_database()
|
||||||
|
|
||||||
|
app.with_authentication()
|
||||||
|
app.with_authorization()
|
||||||
|
|
||||||
|
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
|
||||||
|
app.with_routes_directory("routes")
|
||||||
|
|
||||||
app.run()
|
app.run()
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
from urllib.request import Request
|
from urllib.request import Request
|
||||||
|
|
||||||
|
from service import PingService
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
from cpl.api import APILogger
|
||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.core.log import Logger
|
|
||||||
from custom.api.src.service import PingService
|
|
||||||
|
|
||||||
|
|
||||||
|
@Router.authenticate()
|
||||||
|
# @Router.authorize(permissions=[Permissions.administrator])
|
||||||
|
# @Router.authorize(policies=["test"])
|
||||||
@Router.get(f"/ping")
|
@Router.get(f"/ping")
|
||||||
async def ping(r: Request, ping: PingService, logger: Logger):
|
async def ping(r: Request, ping: PingService, logger: APILogger):
|
||||||
logger.info(f"Ping: {ping}")
|
logger.info(f"Ping: {ping}")
|
||||||
return JSONResponse(ping.ping(r))
|
return JSONResponse(ping.ping(r))
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from model.city import City
|
|||||||
class CityDao(DbModelDaoABC[City]):
|
class CityDao(DbModelDaoABC[City]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, City, "city")
|
DbModelDaoABC.__init__(self, City, "city")
|
||||||
|
|
||||||
self.attribute(City.name, str)
|
self.attribute(City.name, str)
|
||||||
self.attribute(City.zip, int)
|
self.attribute(City.zip, int)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from model.user import User
|
|||||||
class UserDao(DbModelDaoABC[User]):
|
class UserDao(DbModelDaoABC[User]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, __name__, User, "users")
|
DbModelDaoABC.__init__(self, User, "users")
|
||||||
|
|
||||||
self.attribute(User.name, str)
|
self.attribute(User.name, str)
|
||||||
self.attribute(User.city_id, int, db_name="CityId")
|
self.attribute(User.city_id, int, db_name="CityId")
|
||||||
|
|||||||
Reference in New Issue
Block a user