Compare commits

..

1 Commits

Author SHA1 Message Date
69bbbc8cee Authorization via with_route
Some checks failed
Build on push / prepare (push) Successful in 9s
Build on push / query (push) Successful in 20s
Build on push / core (push) Successful in 20s
Build on push / dependency (push) Successful in 17s
Build on push / mail (push) Successful in 15s
Build on push / application (push) Successful in 18s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 25s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 14s
Test before pr merge / test-lint (pull_request) Failing after 6s
2025-09-22 22:03:42 +02:00
190 changed files with 1597 additions and 1661 deletions

3
.gitignore vendored
View File

@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
# cpl unittest stuff
unittests/test_*_playground
# cpl logs
**/logs/*.jsonl

View File

@@ -1,77 +0,0 @@
from starlette.responses import JSONResponse
from cpl import api
from cpl.api.application.web_app import WebApp
from cpl.application import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration
from cpl.core.console import Console
from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache
from custom.api.src.scoped_service import ScopedService
from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging()
builder.services.add_structured_logging()
builder.services.add_transient(PingService)
builder.services.add_module(api)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
app = builder.build()
app.with_logging()
app.with_database()
app.with_authentication()
app.with_authorization()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
app.with_routes_directory("routes")
provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache:
raise Exception("Cache service is not working")
s1 = provider.get_service(ScopedService)
s2 = provider.get_service(ScopedService)
if s1.name == s2.name:
raise Exception("Scoped service is not working")
with provider.create_scope() as scope:
s3 = scope.get_service(ScopedService)
s4 = scope.get_service(ScopedService)
if s3.name != s4.name:
raise Exception("Scoped service is not working")
if s1.name == s3.name:
raise Exception("Scoped service is not working")
Console.write_line(
s1.name,
s2.name,
s3.name,
s4.name,
)
app.run()
if __name__ == "__main__":
main()

View File

@@ -1,21 +0,0 @@
from urllib.request import Request
from service import PingService
from starlette.responses import JSONResponse
from cpl.api import APILogger
from cpl.api.router import Router
from cpl.core.console import Console
from cpl.dependency import ServiceProvider
from custom.api.src.scoped_service import ScopedService
@Router.authenticate()
# @Router.authorize(permissions=[Permissions.administrator])
# @Router.authorize(policies=["test"])
@Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
logger.info(f"Ping: {ping}")
Console.write_line(scoped.name)
return JSONResponse(ping.ping(r))

View File

@@ -1,14 +0,0 @@
from cpl.core.console.console import Console
from cpl.core.utils.string import String
class ScopedService:
def __init__(self):
self._name = String.random(8)
@property
def name(self) -> str:
return self._name
def run(self):
Console.write_line(f"Im {self._name}")

View File

@@ -1,10 +0,0 @@
from cpl.dependency import ServiceProvider, ServiceProvider
from cpl.dependency.inject import inject
from di.test_service import TestService
class StaticTest:
@staticmethod
@inject
def test(services: ServiceProvider, t1: TestService):
t1.run()

View File

@@ -1,10 +0,0 @@
from cpl.core.console import Console
class ScopedService:
def __init__(self):
self.value = "I am a scoped service"
Console.write_line(self.value, self)
def get_value(self):
return self.value

View File

@@ -1,60 +0,0 @@
from cpl.core.console import Console
from cpl.core.utils.benchmark import Benchmark
from cpl.query.enumerable import Enumerable
from cpl.query.immutable_list import ImmutableList
from cpl.query.list import List
from cpl.query.set import Set
def _default():
Console.write_line(Enumerable.empty().to_list())
Console.write_line(Enumerable.range(0, 100).length)
Console.write_line(Enumerable.range(0, 100).to_list())
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
Console.write_line(
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
)
Console.write_line(List)
s =Enumerable.range(0, 10).to_set()
Console.write_line(s)
s.add(1)
Console.write_line(s)
data = Enumerable(
[
{"name": "Alice", "age": 30},
{"name": "Dave", "age": 35},
{"name": "Charlie", "age": 25},
{"name": "Bob", "age": 25},
]
)
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
def t_benchmark(data: list):
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
Benchmark.all(
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
)
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
def main():
N = 10_000_000
data = list(range(N))
#t_benchmark(data)
Console.write_line()
_default()
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,5 @@
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
from .logger import APILogger
from .settings import ApiSettings
def add_api(collection: _ServiceCollection):
try:

View File

@@ -1 +0,0 @@
from .asgi_middleware_abc import ASGIMiddleware

View File

@@ -1 +0,0 @@
from .web_app import WebApp

View File

@@ -27,41 +27,40 @@ from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
from cpl.application.abc.application_abc import ApplicationABC
from cpl.core.configuration import Configuration
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger("API")
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(ApplicationABC):
def __init__(self, services: ServiceProvider):
def __init__(self, services: ServiceProviderABC):
super().__init__(services, [auth, api])
self._app: Starlette | None = None
self._logger = services.get_service(APILogger)
self._api_settings = Configuration.get(ApiSettings)
self._policies = services.get_service(PolicyRegistry)
self._routes = services.get_service(RouteRegistry)
self._middleware: list[Middleware] = []
self._middleware: list[Middleware] = [
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
]
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception,
APIError: self._handle_exception,
}
self.with_middleware(RequestMiddleware)
self.with_middleware(LoggingMiddleware)
async def _handle_exception(self, request: Request, exc: Exception):
@staticmethod
async def _handle_exception(request: Request, exc: Exception):
if isinstance(exc, APIError):
self._logger.error(exc)
_logger.error(exc)
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
if hasattr(request.state, "request_id"):
self._logger.error(f"Request {request.state.request_id}", exc)
_logger.error(f"Request {request.state.request_id}", exc)
else:
self._logger.error("Request unknown", exc)
_logger.error("Request unknown", exc)
return JSONResponse({"error": str(exc)}, status_code=500)
@@ -69,10 +68,10 @@ class WebApp(ApplicationABC):
origins = self._api_settings.allowed_origins
if origins is None or origins == "":
self._logger.warning("No allowed origins specified, allowing all origins")
_logger.warning("No allowed origins specified, allowing all origins")
return ["*"]
self._logger.debug(f"Allowed origins: {origins}")
_logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_database(self) -> Self:
@@ -168,9 +167,9 @@ class WebApp(ApplicationABC):
self._check_for_app()
if isinstance(middleware, Middleware):
self._middleware.append(inject(middleware))
self._middleware.append(middleware)
elif callable(middleware):
self._middleware.append(Middleware(inject(middleware)))
self._middleware.append(Middleware(middleware))
else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
@@ -191,11 +190,11 @@ class WebApp(ApplicationABC):
if isinstance(policy, dict):
for name, resolver in policy.items():
if not isinstance(name, str):
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
_logger.warning(f"Skipping policy at index {i}, name must be a string")
continue
if not callable(resolver):
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
_logger.warning(f"Skipping policy {name}, resolver must be callable")
continue
_policies.append(Policy(name, resolver))
@@ -203,7 +202,7 @@ class WebApp(ApplicationABC):
_policies.append(policy)
self._policies.extend(_policies)
self._policies.extend_policies(_policies)
self.with_middleware(AuthorizationMiddleware)
return self
@@ -213,14 +212,14 @@ class WebApp(ApplicationABC):
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
_logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self):
self._logger.debug(f"Preparing API")
_logger.debug(f"Preparing API")
self._validate_policies()
if self._app is None:
routes = [route.to_starlette(inject) for route in self._routes.all()]
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
app = Starlette(
routes=routes,
@@ -238,7 +237,7 @@ class WebApp(ApplicationABC):
else:
app = self._app
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
@@ -246,4 +245,4 @@ class WebApp(ApplicationABC):
server = uvicorn.Server(config)
await server.serve()
self._logger.info("Shutdown API")
_logger.info("Shutdown API")

View File

@@ -1,7 +1,7 @@
from cpl.core.log.wrapped_logger import WrappedLogger
from cpl.core.log.logger import Logger
class APILogger(WrappedLogger):
class APILogger(Logger):
def __init__(self):
WrappedLogger.__init__(self, "api")
def __init__(self, source: str):
Logger.__init__(self, source, "api")

View File

@@ -1,4 +0,0 @@
from .authentication import AuthenticationMiddleware
from .authorization import AuthorizationMiddleware
from .logging import LoggingMiddleware
from .request import RequestMiddleware

View File

@@ -2,22 +2,24 @@ from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.error import Unauthorized
from cpl.api.logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.core.ctx import set_user
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
@ServiceProviderABC.inject
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._logger = logger
self._keycloak = keycloak
self._user_dao = user_dao
@@ -26,11 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
_logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)
if not request.headers.get("Authorization"):
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None)
@@ -39,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token):
self._logger.debug(f"Unauthorized access to {url}, invalid token")
_logger.debug(f"Unauthorized access to {url}, invalid token")
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create
@@ -49,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
user = await self._get_or_crate_user(keycloak_id)
if user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
_logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
request.state.user = user
@@ -71,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
token_info = self._keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
self._logger.debug(f"Keycloak authentication error: {e}")
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
self._logger.error(f"Unexpected error during token verification: {e}")
_logger.error(f"Unexpected error during token verification: {e}")
return False

View File

@@ -9,15 +9,17 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx.user_context import get_user
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger(__name__)
class AuthorizationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
@ServiceProviderABC.inject
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._logger = logger
self._policies = policies
self._user_dao = user_dao
@@ -26,7 +28,7 @@ class AuthorizationMiddleware(ASGIMiddleware):
url = request.url.path
if url not in Router.get_authorization_rules_paths():
self._logger.trace(f"No authorization required for {url}")
_logger.trace(f"No authorization required for {url}")
return await self._app(scope, receive, send)
user = get_user()
@@ -51,21 +53,17 @@ class AuthorizationMiddleware(ASGIMiddleware):
if rule["permissions"]:
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
scope, receive, send
)
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.warning(f"Authorization policy '{policy_name}' not found")
_logger.warning(f"Authorization policy '{policy_name}' not found")
continue
if not await policy.resolve(user):
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
return await self._call_next(scope, receive, send)

View File

@@ -7,14 +7,14 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger):
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
self._logger = logger
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self._call_next(scope, receive, send)
@@ -53,8 +53,9 @@ class LoggingMiddleware(ASGIMiddleware):
}
return {key: value for key, value in headers.items() if key in relevant_keys}
async def _log_request(self, request: Request):
self._logger.debug(
@classmethod
async def _log_request(cls, request: Request):
_logger.debug(
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
)
@@ -63,7 +64,7 @@ class LoggingMiddleware(ASGIMiddleware):
user = get_user()
request_info = {
"headers": self._filter_relevant_headers(dict(request.headers)),
"headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params),
"form-data": (
await request.form()
@@ -77,9 +78,10 @@ class LoggingMiddleware(ASGIMiddleware):
),
}
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
async def _log_after_request(self, request: Request, status_code: int, duration: float):
self._logger.info(
@staticmethod
async def _log_after_request(request: Request, status_code: int, duration: float):
_logger.info(
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@@ -9,20 +9,16 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
_logger = APILogger(__name__)
class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send):
@@ -30,15 +26,14 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request)
try:
with self._provider.create_scope():
inject(await self._app(scope, receive, send))
await self._app(scope, receive, send)
finally:
await self.clean_request_data()
async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4()
request.state.start_time = time.time()
self._logger.trace(f"Set new current request: {request.state.request_id}")
_logger.trace(f"Set new current request: {request.state.request_id}")
self._ctx_token = _request_context.set(request)
@@ -50,7 +45,7 @@ class RequestMiddleware(ASGIMiddleware):
if self._ctx_token is None:
return
self._logger.trace(f"Clearing current request: {request.state.request_id}")
_logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token)

View File

@@ -1,3 +0,0 @@
from .api_route import ApiRoute
from .policy import Policy
from .validation_match import ValidationMatch

View File

@@ -7,7 +7,13 @@ from cpl.api.typing import HTTPMethods
class ApiRoute:
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
def __init__(
self,
path: str,
fn: Callable,
method: HTTPMethods,
**kwargs
):
self._path = path
self._fn = fn
self._method = method

View File

@@ -1,5 +1,5 @@
from asyncio import iscoroutinefunction
from typing import Optional
from typing import Optional, Any, Coroutine, Awaitable
from cpl.api.typing import PolicyResolver
from cpl.core.ctx import get_user

View File

@@ -1,2 +0,0 @@
from .policy import PolicyRegistry
from .route import RouteRegistry

View File

@@ -1,5 +1,6 @@
from typing import Optional
from cpl.api.model.policy import Policy
from cpl.api.model.api_route import ApiRoute
from cpl.core.abc.registry_abc import RegistryABC

View File

@@ -3,7 +3,6 @@ from enum import Enum
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.route import RouteRegistry
from cpl.api.typing import HTTPMethods
from cpl.dependency import get_provider
class Router:
@@ -42,13 +41,7 @@ class Router:
return inner
@classmethod
def authorize(
cls,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
):
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
"""
Decorator to mark a route as requiring authorization.
Usage:
@@ -92,14 +85,15 @@ class Router:
return inner
@classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.api_route import ApiRoute
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
if not registry:
routes = get_provider().get_service(RouteRegistry)
from cpl.api.model.api_route import ApiRoute
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
else:
routes = registry
def inner(fn):
routes.add(ApiRoute(path, fn, method, **kwargs))
setattr(fn, "_route_path", path)
@@ -143,9 +137,8 @@ class Router:
"""
from cpl.api.model.api_route import ApiRoute
routes = get_provider().get_service(RouteRegistry)
from cpl.dependency.service_provider_abc import ServiceProviderABC
routes = ServiceProviderABC.get_global_service(RouteRegistry)
def inner(fn):
path = getattr(fn, "_route_path", None)
if path is None:
@@ -154,7 +147,7 @@ class Router:
route = routes.get(path)
if route is None:
raise ValueError(f"Cannot override a route that does not exist: {path}")
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
setattr(fn, "_route_path", path)
return fn

View File

@@ -16,4 +16,4 @@ PartialMiddleware = Union[
Middleware,
Callable[[ASGIApp], ASGIApp],
]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]

View File

@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
from typing import Callable, Self
from cpl.application.host import Host
from cpl.core.console.console import Console
from cpl.core.log import LogSettings
from cpl.core.log.log_level import LogLevel
from cpl.core.log.log_settings import LogSettings
from cpl.core.log.logger_abc import LoggerABC
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
def __not_implemented__(package: str, func: Callable):
@@ -16,12 +17,12 @@ class ApplicationABC(ABC):
r"""ABC for the Application class
Parameters:
services: :class:`cpl.dependency.service_provider.ServiceProvider`
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
Contains instances of prepared objects
"""
@abstractmethod
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
self._services = services
self._required_modules = (
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []

View File

@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency import ServiceProviderABC
class ApplicationExtensionABC(ABC):
@staticmethod
@abstractmethod
def run(services: ServiceProvider): ...
def run(services: ServiceProviderABC): ...

View File

@@ -7,7 +7,6 @@ from cpl.application.abc.startup_abc import StartupABC
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
from cpl.application.host import Host
from cpl.core.errors import dependency_error
from cpl.dependency.context import get_provider, use_root_provider
from cpl.dependency.service_collection import ServiceCollection
TApp = TypeVar("TApp", bound=ApplicationABC)
@@ -22,7 +21,6 @@ class ApplicationBuilder(Generic[TApp]):
self._app = app if app is not None else ApplicationABC
self._services = ServiceCollection()
use_root_provider(self._services.build())
self._startup: Optional[StartupABC] = None
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
@@ -36,12 +34,7 @@ class ApplicationBuilder(Generic[TApp]):
@property
def service_provider(self):
provider = get_provider()
if provider is None:
provider = self._services.build()
use_root_provider(provider)
return provider
return self._services.build()
def validate_app_required_modules(self, app: ApplicationABC):
for module in app.required_modules:

View File

@@ -6,7 +6,7 @@ from cpl.auth import permission as _permission
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
from .logger import AuthLogger
from .auth_logger import AuthLogger
from .keycloak_settings import KeycloakSettings
from .permission_seeder import PermissionSeeder

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class AuthLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "auth")

View File

@@ -1,13 +1,15 @@
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings
from cpl.auth.logger import AuthLogger
_logger = AuthLogger("keycloak")
class KeycloakAdmin(_KeycloakAdmin):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
# logger.info("Initializing Keycloak admin")
def __init__(self, settings: KeycloakSettings):
_logger.info("Initializing Keycloak admin")
_connection = KeycloakOpenIDConnection(
server_url=settings.url,
client_id=settings.client_id,

View File

@@ -2,13 +2,15 @@ from typing import Optional
from keycloak import KeycloakOpenID
from cpl.auth.logger import AuthLogger
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.keycloak_settings import KeycloakSettings
_logger = AuthLogger("keycloak")
class KeycloakClient(KeycloakOpenID):
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
def __init__(self, settings: KeycloakSettings):
KeycloakOpenID.__init__(
self,
server_url=settings.url,
@@ -16,7 +18,7 @@ class KeycloakClient(KeycloakOpenID):
realm_name=settings.realm,
client_secret_key=settings.client_secret,
)
logger.info("Initializing Keycloak client")
_logger.info("Initializing Keycloak client")
def get_user_id(self, token: str) -> Optional[str]:
info = self.introspect(token)

View File

@@ -1,5 +1,5 @@
from cpl.core.utils.get_value import get_value
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class KeycloakUser:
@@ -32,5 +32,5 @@ class KeycloakUser:
def id(self) -> str:
from cpl.auth import KeycloakAdmin
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user_id(self._username)

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class AuthLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "auth")

View File

@@ -14,13 +14,14 @@ from cpl.auth.schema import (
)
from cpl.core.utils.get_value import get_value
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionSeeder(DataSeederABC):
def __init__(
self,
logger: DBLogger,
permission_dao: PermissionDao,
role_dao: RoleDao,
role_permission_dao: RolePermissionDao,
@@ -28,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
api_key_permission_dao: ApiKeyPermissionDao,
):
DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao
self._role_dao = role_dao
self._role_permission_dao = role_permission_dao
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
possible_permissions = [permission for permission in PermissionsRegistry.get()]
if len(permissions) == len(possible_permissions):
self._logger.info("Permissions already existing")
_logger.info("Permissions already existing")
await self._update_missing_descriptions()
return
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
await self._permission_dao.delete_many(to_delete, hard_delete=True)
self._logger.warning("Permissions incomplete")
_logger.warning("Permissions incomplete")
permission_names = [permission.name for permission in permissions]
await self._permission_dao.create_many(
[

View File

@@ -10,8 +10,7 @@ from cpl.core.log.logger import Logger
from cpl.core.typing import Id, SerialId
from cpl.core.utils.credential_manager import CredentialManager
from cpl.database.abc.db_model_abc import DbModelABC
from cpl.dependency import get_provider
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = Logger(__name__)
@@ -48,7 +47,7 @@ class ApiKey(DbModelABC):
async def permissions(self):
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._administration.api_key import ApiKey
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyDao(DbModelDaoABC[ApiKey]):
def __init__(self):
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys"))
DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
self.attribute(ApiKey.identifier, str)
self.attribute(ApiKey.key, str, "keystring")

View File

@@ -6,11 +6,13 @@ from async_property import async_property
from keycloak import KeycloakGetError
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
_logger = AuthLogger(__name__)
class AuthUser(DbModelABC):
@@ -36,13 +38,12 @@ class AuthUser(DbModelABC):
return "ANONYMOUS"
try:
keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("username")
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user(self._keycloak_id).get("username")
except KeycloakGetError as e:
return "UNKNOWN"
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN"
@property
@@ -51,39 +52,38 @@ class AuthUser(DbModelABC):
return "ANONYMOUS"
try:
keycloak = get_provider().get_service(KeycloakAdmin)
return keycloak.get_user(self._keycloak_id).get("email")
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
return keycloak_admin.get_user(self._keycloak_id).get("email")
except KeycloakGetError as e:
return "UNKNOWN"
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
return "UNKNOWN"
@async_property
async def roles(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
@async_property
async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.has_permission(self.id, permission)
async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self)

View File

@@ -4,14 +4,17 @@ from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users"))
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
@@ -36,7 +39,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
result = await self._db.select_map(
f"""

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class ApiKeyPermission(DbJoinModelABC):
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
async def api_key(self):
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
return await api_key_dao.get_by_id(self._api_key_id)
@property
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
def __init__(self):
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions"))
DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
self.attribute(ApiKeyPermission.api_key_id, int)
self.attribute(ApiKeyPermission.permission_id, int)

View File

@@ -3,12 +3,15 @@ from typing import Optional
from cpl.auth.schema._permission.permission import Permission
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class PermissionDao(DbModelDaoABC[Permission]):
def __init__(self):
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions"))
DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
self.attribute(Permission.name, str)
self.attribute(Permission.description, Optional[str])

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class Role(DbModelABC):
@@ -44,22 +44,22 @@ class Role(DbModelABC):
async def permissions(self):
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
@async_property
async def users(self):
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
p = await permission_dao.get_by_name(permission.value)

View File

@@ -1,11 +1,14 @@
from cpl.auth.schema._permission.role import Role
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleDao(DbModelDaoABC[Role]):
def __init__(self):
DbModelDaoABC.__init__(self, Role, TableManager.get("roles"))
DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
self.attribute(Role.name, str)
self.attribute(Role.description, str)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class RolePermission(DbModelABC):
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao)
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id)
@property
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_permission import RolePermission
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RolePermissionDao(DbModelDaoABC[RolePermission]):
def __init__(self):
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions"))
DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
self.attribute(RolePermission.role_id, int)
self.attribute(RolePermission.permission_id, int)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider
from cpl.dependency import ServiceProviderABC
class RoleUser(DbJoinModelABC):
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._user_id)
@property
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao)
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
return await role_dao.get_by_id(self._role_id)

View File

@@ -1,12 +1,15 @@
from cpl.auth.schema._permission.role_user import RoleUser
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.db_logger import DBLogger
_logger = DBLogger(__name__)
class RoleUserDao(DbModelDaoABC[RoleUser]):
def __init__(self):
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users"))
DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
self.attribute(RoleUser.role_id, int)
self.attribute(RoleUser.user_id, int)

View File

@@ -1,18 +1,17 @@
from contextvars import ContextVar
from typing import Optional
from cpl.auth.auth_logger import AuthLogger
from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
_logger = AuthLogger(__name__)
def set_user(user: Optional[AuthUser]):
from cpl.core.log.logger_abc import LoggerABC
logger = get_provider().get_service(LoggerABC)
logger.trace("Setting user context", user.id)
_user_context.set(user)
def set_user(user_id: Optional[AuthUser]):
_logger.trace("Setting user context", user_id)
_user_context.set(user_id)
def get_user() -> Optional[AuthUser]:

View File

@@ -2,4 +2,3 @@ from .logger import Logger
from .logger_abc import LoggerABC
from .log_level import LogLevel
from .log_settings import LogSettings
from .structured_logger import StructuredLogger

View File

@@ -1,111 +0,0 @@
import asyncio
import importlib.util
import json
import traceback
from datetime import datetime
from starlette.requests import Request
from cpl.core.log.log_level import LogLevel
from cpl.core.log.logger import Logger
from cpl.core.typing import Source, Messages
from cpl.dependency import get_provider
class StructuredLogger(Logger):
def __init__(self, source: Source, file_prefix: str = None):
Logger.__init__(self, source, file_prefix)
@property
def log_file(self):
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
def _log(self, level: LogLevel, *messages: Messages):
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = self._format_message(level.value, timestamp, *messages)
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
self._write_log_to_file(level, structured_message)
self._write_to_console(level, formatted_message)
except Exception as e:
print(f"Error while logging: {e} -> {traceback.format_exc()}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self._source,
"messages": messages,
}
self._enrich_message_with_request(structured_message)
self._enrich_message_with_user(structured_message)
return json.dumps(structured_message, ensure_ascii=False)
@staticmethod
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
scope = dict(request.scope)
def convert(value):
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, (list, tuple)):
return [convert(v) for v in value]
if isinstance(value, dict):
return {str(k): convert(v) for k, v in value.items()}
if not isinstance(value, (str, int, float, bool, type(None))):
return str(value)
return value
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
if not include_headers and "headers" in serializable_scope:
serializable_scope["headers"] = "<omitted>"
return serializable_scope
def _enrich_message_with_request(self, message: dict):
if importlib.util.find_spec("cpl.api") is None:
return
from cpl.api.middleware.request import get_request
from starlette.requests import Request
request = get_request()
if request is None:
return
message["request"] = {
"url": str(request.url),
"method": request.method,
"scope": self._scope_to_json(request),
}
if isinstance(request, Request) and request.scope == "http":
request: Request = request # fix typing for IDEs
message["request"]["data"] = asyncio.create_task(request.body())
@staticmethod
def _enrich_message_with_user(message: dict):
if importlib.util.find_spec("cpl-auth") is None:
return
from cpl.core.ctx import get_user
user = get_user()
if user is None:
return
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
keycloak = get_provider().get_service(KeycloakAdmin)
kc_user = keycloak.get_user(user.keycloak_id)
message["user"] = {
"id": str(user.id),
"username": kc_user.get("username"),
"email": kc_user.get("email"),
}

View File

@@ -1,101 +0,0 @@
import inspect
from typing import Type
from cpl.core.log import LoggerABC, LogLevel
from cpl.core.typing import Messages
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
class WrappedLogger(LoggerABC):
def __init__(self, file_prefix: str):
LoggerABC.__init__(self)
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
self._source = None
self._file_prefix = file_prefix
self._set_logger()
@inject
def _set_logger(self, services: ServiceProvider):
from cpl.core.log import Logger
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
if t_logger is None:
raise Exception("No LoggerABC service registered in ServiceProvider")
self._logger = t_logger(self._source, self._file_prefix)
def set_level(self, level: LogLevel):
self._logger.set_level(level)
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
return self._logger._format_message(level, timestamp, *messages)
@staticmethod
def _get_source() -> str | None:
stack = inspect.stack()
if len(stack) <= 1:
return None
from cpl.dependency import ServiceCollection
ignore_classes = [
ServiceProvider,
ServiceProvider.__subclasses__(),
ServiceCollection,
WrappedLogger,
WrappedLogger.__subclasses__(),
]
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
for i, frame_info in enumerate(stack[1:]):
module = inspect.getmodule(frame_info.frame)
if module is None:
continue
if module.__name__ in ignore_classes or module in ignore_classes:
continue
if module in ignore_modules or module.__name__ in ignore_modules:
continue
if module.__name__ != __name__:
return module.__name__
return None
def _set_source(self):
self._source = self._get_source()
self._set_logger()
def header(self, string: str):
self._set_source()
self._logger.header(string)
def trace(self, *messages: Messages):
self._set_source()
self._logger.trace(*messages)
def debug(self, *messages: Messages):
self._set_source()
self._logger.debug(*messages)
def info(self, *messages: Messages):
self._set_source()
self._logger.info(*messages)
def warning(self, *messages: Messages):
self._set_source()
self._logger.warning(*messages)
def error(self, messages: str, e: Exception = None):
self._set_source()
self._logger.error(messages, e)
def fatal(self, messages: str, e: Exception = None):
self._set_source()
self._logger.fatal(messages, e)

View File

@@ -14,4 +14,3 @@ UuidId = str | UUID
SerialId = int
Id = UuidId | SerialId
TNumber = int | float | complex

View File

@@ -1,57 +0,0 @@
import time
import tracemalloc
from typing import List, Callable
from cpl.core.console import Console
class Benchmark:
@staticmethod
def all(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
mems: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_time = sum(times) / len(times)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
@staticmethod
def time(label: str, func: Callable, iterations: int = 5):
times: List[float] = []
for _ in range(iterations):
start = time.perf_counter()
func()
end = time.perf_counter()
times.append(end - start)
avg_time = sum(times) / len(times)
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
@staticmethod
def memory(label: str, func: Callable, iterations: int = 5):
mems: List[float] = []
for _ in range(iterations):
tracemalloc.start()
func()
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
mems.append(peak)
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")

View File

@@ -1,100 +0,0 @@
import threading
import time
from typing import Generic
from cpl.core.typing import T
class Cache(Generic[T]):
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
self._store = {}
self._default_ttl = default_ttl
self._lock = threading.Lock()
self._cleanup_interval = cleanup_interval
self._stop_event = threading.Event()
self._type = t
# Start background cleanup thread
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
self._thread.start()
def set(self, key: str, value: T, ttl: int = None) -> None:
"""Store a value in the cache with optional TTL override."""
expire_at = None
ttl = ttl if ttl is not None else self._default_ttl
if ttl is not None:
expire_at = time.time() + ttl
with self._lock:
self._store[key] = (value, expire_at)
def get(self, key: str) -> T | None:
"""Retrieve a value from the cache if not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return None
value, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return None
del self._store[key]
return None
return value
def get_all(self) -> list[T]:
"""Retrieve all non-expired values from the cache."""
now = time.time()
with self._lock:
valid_items = []
expired_keys = []
for k, (v, exp) in self._store.items():
if exp and exp < now:
expired_keys.append(k)
else:
valid_items.append(v)
for k in expired_keys:
del self._store[k]
return valid_items
def has(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return False
_, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return False
del self._store[key]
return False
return True
def delete(self, key: str) -> None:
"""Remove an item from the cache."""
with self._lock:
self._store.pop(key, None)
def clear(self) -> None:
"""Clear the entire cache."""
with self._lock:
self._store.clear()
def _auto_cleanup(self):
"""Background thread to clean expired items."""
while not self._stop_event.is_set():
self.cleanup()
self._stop_event.wait(self._cleanup_interval)
def cleanup(self) -> None:
"""Remove expired items immediately."""
now = time.time()
with self._lock:
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
for k in expired_keys:
del self._store[k]
def stop(self):
"""Stop the background cleanup thread."""
self._stop_event.set()
self._thread.join()

View File

@@ -1,48 +0,0 @@
from typing import Any
class Number:
@staticmethod
def is_number(value: Any) -> bool:
"""Check if the value is a number (int or float)."""
return isinstance(value, (int, float, complex))
@staticmethod
def to_number(value: Any) -> int | float | complex:
"""
Convert a given value into int, float, or complex.
Raises ValueError if conversion is not possible.
"""
if isinstance(value, (int, float, complex)):
return value
if isinstance(value, str):
value = value.strip()
for caster in (int, float, complex):
try:
return caster(value)
except ValueError:
continue
raise ValueError(f"Cannot convert string '{value}' to number.")
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
pass
try:
return float(value)
except Exception:
pass
try:
return complex(value)
except Exception:
pass
raise ValueError(f"Cannot convert type {type(value)} to number.")

View File

@@ -114,15 +114,12 @@ class String:
characters = []
if letters:
characters.extend(string.ascii_letters)
characters.append(string.ascii_letters)
if digits:
characters.extend(string.digits)
characters.append(string.digits)
if special_characters:
characters.extend(string.punctuation)
characters.append(string.punctuation)
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else ""
if len(x) != length:
raise Exception("No characters selected to generate random string")
return x
return "".join(random.choice(characters) for _ in range(length)) if characters else ""

View File

@@ -9,19 +9,25 @@ from cpl.core.utils.get_value import get_value
from cpl.core.utils.string import String
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.const import DATETIME_FORMAT
from cpl.database.db_logger import DBLogger
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.database.logger import DBLogger
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
from cpl.dependency import get_provider
class DataAccessObjectABC(ABC, Generic[T_DBM]):
@abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str):
self._db = get_provider().get_service(DBContextABC)
self._logger = get_provider().get_service(DBLogger)
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
from cpl.dependency.service_provider_abc import ServiceProviderABC
self._db = ServiceProviderABC.get_global_service(DBContextABC)
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name
self._logger = DBLogger(source)
self._model_type = model_type
self._table_name = table_name

View File

@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
@abstractmethod
def __init__(self, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, model_type, table_name)
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
DataAccessObjectABC.__init__(self, source, model_type, table_name)
self.attribute(DbModelABC.id, int, ignore=True)
self.attribute(DbModelABC.deleted, bool)

View File

@@ -0,0 +1,8 @@
from cpl.core.log import Logger
from cpl.core.typing import Source
class DBLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "db")

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class DBLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "db")

View File

@@ -4,17 +4,18 @@ from typing import Any, List, Dict, Tuple, Union
from mysql.connector import Error as MySQLError, PoolError
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model.database_settings import DatabaseSettings
from cpl.database.mysql.mysql_pool import MySQLPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC):
def __init__(self, logger: DBLogger):
def __init__(self):
DBContextABC.__init__(self)
self._logger = logger
self._pool: MySQLPool = None
self._fails = 0
@@ -22,62 +23,62 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings):
try:
self._logger.debug("Connecting to database")
_logger.debug("Connecting to database")
self._pool = MySQLPool(
database_settings,
)
self._logger.info("Connected to database")
_logger.info("Connected to database")
except Exception as e:
self._logger.fatal("Connecting to database failed", e)
_logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
self._logger.trace(f"execute {statement} with args: {args}")
_logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> List[Dict]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select_map(statement, args)
except (MySQLError, PoolError) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select_map(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select(statement, args)
except (MySQLError, PoolError) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e

View File

@@ -4,9 +4,10 @@ import sqlparse
from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProvider
_logger = DBLogger(__name__)
class MySQLPool:
@@ -35,8 +36,7 @@ class MySQLPool:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.fatal(f"Error connecting to the database: {e}")
_logger.fatal(f"Error connecting to the database: {e}")
finally:
await con.close()

View File

@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
from cpl.core.configuration import Configuration
from cpl.core.environment import Environment
from cpl.database.abc.db_context_abc import DBContextABC
from cpl.database.logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.database.database_settings import DatabaseSettings
from cpl.database.db_logger import DBLogger
from cpl.database.postgres.postgres_pool import PostgresPool
_logger = DBLogger(__name__)
class DBContext(DBContextABC):
def __init__(self, logger: DBLogger):
def __init__(self):
DBContextABC.__init__(self)
self._logger = logger
self._pool: PostgresPool = None
self._fails = 0
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
def connect(self, database_settings: DatabaseSettings):
try:
self._logger.debug("Connecting to database")
_logger.debug("Connecting to database")
self._pool = PostgresPool(
database_settings,
Environment.get("DB_POOL_SIZE", int, 1),
)
self._logger.info("Connected to database")
_logger.info("Connected to database")
except Exception as e:
self._logger.fatal("Connecting to database failed", e)
_logger.fatal("Connecting to database failed", e)
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
self._logger.trace(f"execute {statement} with args: {args}")
_logger.trace(f"execute {statement} with args: {args}")
return await self._pool.execute(statement, args, multi)
async def select_map(self, statement: str, args=None) -> list[dict]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select_map(statement, args)
except (OperationalError, PoolTimeout) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select_map(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
self._logger.trace(f"select {statement} with args: {args}")
_logger.trace(f"select {statement} with args: {args}")
try:
return await self._pool.select(statement, args)
except (OperationalError, PoolTimeout) as e:
if self._fails >= 3:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
uid = uuid.uuid4()
raise Exception(
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
)
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
self._fails += 1
try:
self._logger.debug("Retry select")
_logger.debug("Retry select")
return await self.select(statement, args)
except Exception as e:
pass
return []
except Exception as e:
self._logger.error(f"Database error caused by `{statement}`", e)
_logger.error(f"Database error caused by `{statement}`", e)
raise e

View File

@@ -5,9 +5,10 @@ from psycopg import sql
from psycopg_pool import AsyncConnectionPool, PoolTimeout
from cpl.core.environment import Environment
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model import DatabaseSettings
from cpl.dependency import ServiceProvider
_logger = DBLogger(__name__)
class PostgresPool:
@@ -37,8 +38,7 @@ class PostgresPool:
await pool.check_connection(con)
except PoolTimeout as e:
await pool.close()
logger = get_provider().get_service(DBLogger)
logger.fatal(f"Failed to connect to the database", e)
_logger.fatal(f"Failed to connect to the database", e)
self._pool = pool
return self._pool

View File

@@ -1,11 +1,14 @@
from cpl.database import TableManager
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.database.db_logger import DBLogger
from cpl.database.schema.executed_migration import ExecutedMigration
_logger = DBLogger(__name__)
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
def __init__(self):
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations"))
DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")

View File

@@ -2,17 +2,18 @@ import glob
import os
from cpl.database.abc import DBContextABC
from cpl.database.logger import DBLogger
from cpl.database.db_logger import DBLogger
from cpl.database.model import Migration
from cpl.database.model.server_type import ServerType, ServerTypes
from cpl.database.schema.executed_migration import ExecutedMigration
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
_logger = DBLogger(__name__)
class MigrationService:
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
self._logger = logger
def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
self._db = db
self._executedMigrationDao = executedMigrationDao
@@ -95,13 +96,13 @@ class MigrationService:
if migration_from_db is not None:
continue
self._logger.debug(f"Running upgrade migration: {migration.name}")
_logger.debug(f"Running upgrade migration: {migration.name}")
await self._db.execute(migration.script, multi=True)
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
except Exception as e:
self._logger.fatal(
_logger.fatal(
f"Migration failed: {migration.name}\n{active_statement}",
e,
)

View File

@@ -1,16 +1,18 @@
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProvider
from cpl.database.db_logger import DBLogger
from cpl.dependency import ServiceProviderABC
_logger = DBLogger(__name__)
class SeederService:
def __init__(self, provider: ServiceProvider):
def __init__(self, provider: ServiceProviderABC):
self._provider = provider
self._logger = provider.get_service(DBLogger)
async def seed(self):
seeders = self._provider.get_services(DataSeederABC)
self._logger.debug(f"Found {len(seeders)} seeders")
_logger.debug(f"Found {len(seeders)} seeders")
for seeder in seeders:
await seeder.seed()

View File

@@ -1,7 +1,7 @@
from .context import get_provider, use_provider
from .inject import inject
from .scope import Scope
from .scope_abc import ScopeABC
from .service_collection import ServiceCollection
from .service_descriptor import ServiceDescriptor
from .service_lifetime import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_lifetime_enum import ServiceLifetimeEnum
from .service_provider import ServiceProvider
from .service_provider_abc import ServiceProviderABC

View File

@@ -1,21 +0,0 @@
import contextvars
from contextlib import contextmanager
_current_provider = contextvars.ContextVar("current_provider", default=None)
def use_root_provider(provider):
_current_provider.set(provider)
@contextmanager
def use_provider(provider):
token = _current_provider.set(provider)
try:
yield
finally:
_current_provider.reset(token)
def get_provider():
return _current_provider.get()

View File

@@ -1,42 +0,0 @@
import functools
from asyncio import iscoroutinefunction
from inspect import signature
from cpl.dependency.context import get_provider
def inject(f=None):
if f is None:
return functools.partial(inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
from cpl.dependency.service_provider import ServiceProvider
provider: ServiceProvider | None = get_provider()
if provider is None:
raise ValueError(
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
)
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -0,0 +1,22 @@
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class Scope(ScopeABC):
def __init__(self, service_provider: ServiceProviderABC):
self._service_provider = service_provider
self._service_provider.set_scope(self)
ScopeABC.__init__(self)
def __enter__(self):
return self
def __exit__(self, *args):
self.dispose()
@property
def service_provider(self) -> ServiceProviderABC:
return self._service_provider
def dispose(self):
self._service_provider = None

View File

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
class ScopeABC(ABC):
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
def __init__(self): ...
@property
@abstractmethod
def service_provider(self):
r"""Returns to service provider of scope
Returns:
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
"""
@abstractmethod
def dispose(self):
r"""Sets service_provider to None"""

View File

@@ -0,0 +1,18 @@
from cpl.dependency.scope import Scope
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ScopeBuilder:
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
def __init__(self, service_provider: ServiceProviderABC) -> None:
self._service_provider = service_provider
def build(self) -> ScopeABC:
r"""Returns scope
Returns:
Object of type :class:`cpl.dependency.scope.Scope`
"""
return Scope(self._service_provider)

View File

@@ -1,11 +1,12 @@
from typing import Union, Type, Callable, Self
from cpl.core.log.logger import Logger
from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache
from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceCollection:
@@ -61,8 +62,9 @@ class ServiceCollection:
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
return self
def build(self) -> ServiceProvider:
def build(self) -> ServiceProviderABC:
sp = ServiceProvider(self._service_descriptors)
ServiceProviderABC.set_global_provider(sp)
return sp
def add_module(self, module: str | object) -> Self:
@@ -78,24 +80,5 @@ class ServiceCollection:
return self
def add_logging(self) -> Self:
from cpl.core.log.logger import Logger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, Logger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_structured_logging(self) -> Self:
from cpl.core.log.structured_logger import StructuredLogger
from cpl.core.log.wrapped_logger import WrappedLogger
self.add_transient(LoggerABC, StructuredLogger)
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_cache(self, t: Type[T]):
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
return self

View File

@@ -1,6 +1,6 @@
from typing import Union, Optional
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
class ServiceDescriptor:

View File

@@ -1,7 +0,0 @@
from enum import Enum, auto
class ServiceLifetimeEnum(Enum):
singleton = auto()
scoped = auto()
transient = auto()

View File

@@ -0,0 +1,7 @@
from enum import Enum
class ServiceLifetimeEnum(Enum):
singleton = 0
scoped = 1
transient = 2

View File

@@ -1,41 +1,44 @@
import copy
import typing
from contextlib import contextmanager
from inspect import signature, Parameter, Signature
from typing import Optional, Type
from typing import Optional
from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
from cpl.core.environment import Environment
from cpl.core.typing import T, R, Source
from cpl.dependency import use_provider
from cpl.dependency.scope_abc import ScopeABC
from cpl.dependency.scope_builder import ScopeBuilder
from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider_abc import ServiceProviderABC
class ServiceProvider:
def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False):
class ServiceProvider(ServiceProviderABC):
r"""Provider for the services
Parameter
---------
service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`]
Descriptor of the service
config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC`
CPL Configuration
db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`]
Database representation
"""
def __init__(
self,
service_descriptors: list[ServiceDescriptor],
):
ServiceProviderABC.__init__(self)
self._service_descriptors: list[ServiceDescriptor] = service_descriptors
self._is_scope = is_scope
self._scope: Optional[ScopeABC] = None
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type
type_args = list(typing.get_args(service_type))
for descriptor in self._service_descriptors:
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type
descriptor_type_args = list(typing.get_args(descriptor.base_type))
if descriptor_base_type == origin_type and len(descriptor_type_args) == 0 and len(type_args) == 0:
return descriptor
if descriptor_base_type != origin_type or len(descriptor_type_args) != len(type_args):
continue
if descriptor_base_type == origin_type and type_args != descriptor_type_args:
continue
if descriptor.service_type == origin_type or issubclass(descriptor.base_type, origin_type):
if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
return descriptor
return None
@@ -49,7 +52,7 @@ class ServiceProvider:
return descriptor.implementation
implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptor.implementation = implementation
return implementation
@@ -67,7 +70,7 @@ class ServiceProvider:
implementation = self._build_service(
descriptor.service_type, origin_service_type=service_type, **kwargs
)
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptor.implementation = implementation
implementations.append(implementation)
@@ -85,7 +88,7 @@ class ServiceProvider:
elif parameter.annotation == Source:
params.append(origin_service_type.__name__)
elif issubclass(parameter.annotation, ServiceProvider):
elif issubclass(parameter.annotation, ServiceProviderABC):
params.append(self)
elif issubclass(parameter.annotation, Environment):
@@ -113,27 +116,32 @@ class ServiceProvider:
service_type = type(descriptor.implementation)
else:
service_type = descriptor.service_type
break
sig = signature(service_type.__init__)
params = self._build_by_signature(sig, origin_service_type)
return service_type(*params, *args, **kwargs)
@contextmanager
def create_scope(self):
scoped_descriptors = []
for d in self._service_descriptors:
if d.lifetime == ServiceLifetimeEnum.singleton:
scoped_descriptors.append(d)
else:
scoped_descriptors.append(copy.deepcopy(d))
def set_scope(self, scope: ScopeABC):
self._scope = scope
scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True)
with use_provider(scoped_provider):
yield scoped_provider
def create_scope(self) -> ScopeABC:
descriptors = []
for descriptor in self._service_descriptors:
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
descriptors.append(descriptor)
else:
descriptors.append(copy.deepcopy(descriptor))
sb = ScopeBuilder(ServiceProvider(descriptors))
return sb.build()
def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
result = self._find_service(service_type)
if result is None:
return None
@@ -141,30 +149,21 @@ class ServiceProvider:
return result.implementation
implementation = self._build_service(service_type, *args, **kwargs)
if result.lifetime == ServiceLifetimeEnum.singleton:
result.implementation = implementation
elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope:
if (
result.lifetime == ServiceLifetimeEnum.singleton
or result.lifetime == ServiceLifetimeEnum.scoped
and self._scope is not None
):
result.implementation = implementation
return implementation
def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]:
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
return descriptor.service_type
return None
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
implementations = []
if typing.get_origin(service_type) == list:
raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
implementations.extend(self._get_services(service_type))
return implementations
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
types = []
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
types.append(descriptor.service_type)
return types
implementations.extend(self._get_services(service_type))
return implementations

View File

@@ -0,0 +1,137 @@
import functools
from abc import abstractmethod, ABC
from inspect import Signature, signature, iscoroutinefunction
from typing import Optional, Type
from cpl.core.typing import T, R
from cpl.dependency.scope_abc import ScopeABC
class ServiceProviderABC(ABC):
r"""ABC for the class :class:`cpl.dependency.service_provider.ServiceProvider`"""
_provider: Optional["ServiceProviderABC"] = None
@abstractmethod
def __init__(self): ...
@classmethod
def set_global_provider(cls, provider: "ServiceProviderABC"):
cls._provider = provider
@classmethod
def get_global_provider(cls) -> Optional["ServiceProviderABC"]:
return cls._provider
@classmethod
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
if cls._provider is None:
return None
return cls._provider.get_service(instance_type, *args, **kwargs)
@classmethod
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
if cls._provider is None:
return []
return cls._provider.get_services(instance_type, *args, **kwargs)
@abstractmethod
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
@abstractmethod
def _build_service(self, service_type: type, *args, **kwargs) -> object:
r"""Creates instance of given type
Parameter
---------
instance_type: :class:`type`
The type of the searched instance
Returns
-------
Object of the given type
"""
@abstractmethod
def set_scope(self, scope: ScopeABC):
r"""Sets the scope of service provider
Parameter
---------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
Service scope
"""
@abstractmethod
def create_scope(self) -> ScopeABC:
r"""Creates a service scope
Returns
-------
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
"""
@abstractmethod
def get_service(self, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
r"""Returns instance of given type
Parameter
---------
instance_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type Optional[:class:`cpl.core.type.T`]
"""
@abstractmethod
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
r"""Returns instance of given type
Parameter
---------
service_type: :class:`cpl.core.type.T`
The type of the searched instance
Returns
-------
Object of type list[Optional[:class:`cpl.core.type.T`]
"""
@classmethod
def inject(cls, f=None):
r"""Decorator to allow injection into static and class methods
Parameter
---------
f: Callable
Returns
-------
function
"""
if f is None:
return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return await f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -3,7 +3,7 @@ from .abc.email_client_abc import EMailClientABC
from .email_client import EMailClient
from .email_client_settings import EMailClientSettings
from .email_model import EMail
from .logger import MailLogger
from .mail_logger import MailLogger
def add_mail(collection: _ServiceCollection):

View File

@@ -5,7 +5,7 @@ from typing import Optional
from cpl.mail.abc.email_client_abc import EMailClientABC
from cpl.mail.email_client_settings import EMailClientSettings
from cpl.mail.email_model import EMail
from cpl.mail.logger import MailLogger
from cpl.mail.mail_logger import MailLogger
class EMailClient(EMailClientABC):

View File

@@ -1,7 +0,0 @@
from cpl.core.log.wrapped_logger import WrappedLogger
class MailLogger(WrappedLogger):
def __init__(self):
WrappedLogger.__init__(self, "mail")

View File

@@ -0,0 +1,8 @@
from cpl.core.log.logger import Logger
from cpl.core.typing import Source
class MailLogger(Logger):
def __init__(self, source: Source):
Logger.__init__(self, source, "mail")

View File

@@ -1,7 +1 @@
from .array import Array
from .enumerable import Enumerable
from .immutable_list import ImmutableList
from .immutable_set import ImmutableSet
from .list import List
from .ordered_enumerable import OrderedEnumerable
from .set import Set

View File

@@ -0,0 +1,2 @@
def is_number(t: type) -> bool:
return issubclass(t, int) or issubclass(t, float) or issubclass(t, complex)

View File

@@ -1,44 +0,0 @@
from typing import Generic, Iterable, Optional
from cpl.core.typing import T
from cpl.query.list import List
from cpl.query.enumerable import Enumerable
class Array(Generic[T], List[T]):
def __init__(self, length: int, source: Optional[Iterable[T]] = None):
List.__init__(self, source)
self._length = length
@property
def length(self) -> int:
return len(self._source)
def add(self, item: T) -> None:
if self._length == self.length:
raise IndexError("Array is full")
self._source.append(item)
def extend(self, items: Iterable[T]) -> None:
if self._length == self.length:
raise IndexError("Array is full")
self._source.extend(items)
def insert(self, index: int, item: T) -> None:
if index < 0 or index > self.length:
raise IndexError("Index out of range")
self._source.insert(index, item)
def remove(self, item: T) -> None:
self._source.remove(item)
def pop(self, index: int = -1) -> T:
return self._source.pop(index)
def clear(self) -> None:
self._source.clear()
def to_enumerable(self) -> "Enumerable[T]":
from cpl.query.enumerable import Enumerable
return Enumerable(self._source)

View File

@@ -0,0 +1,5 @@
from .default_lambda import default_lambda
from .ordered_queryable import OrderedQueryable
from .sequence import Sequence
from .ordered_queryable_abc import OrderedQueryableABC
from .queryable_abc import QueryableABC

View File

@@ -0,0 +1,2 @@
def default_lambda(x: object):
return x

View File

@@ -0,0 +1,34 @@
from collections.abc import Callable
from cpl.query.base.ordered_queryable_abc import OrderedQueryableABC
from cpl.query.exceptions import ArgumentNoneException, ExceptionArgument
class OrderedQueryable(OrderedQueryableABC):
r"""Implementation of :class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`"""
def __init__(self, _t: type, _values: OrderedQueryableABC = None, _func: Callable = None):
OrderedQueryableABC.__init__(self, _t, _values, _func)
def then_by(self, _func: Callable) -> OrderedQueryableABC:
if self is None:
raise ArgumentNoneException(ExceptionArgument.list)
if _func is None:
raise ArgumentNoneException(ExceptionArgument.func)
self._funcs.append(_func)
return OrderedQueryable(self.type, sorted(self, key=lambda *args: [f(*args) for f in self._funcs]), _func)
def then_by_descending(self, _func: Callable) -> OrderedQueryableABC:
if self is None:
raise ArgumentNoneException(ExceptionArgument.list)
if _func is None:
raise ArgumentNoneException(ExceptionArgument.func)
self._funcs.append(_func)
return OrderedQueryable(
self.type, sorted(self, key=lambda *args: [f(*args) for f in self._funcs], reverse=True), _func
)

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable
from typing import Iterable
from cpl.query.base.queryable_abc import QueryableABC
class OrderedQueryableABC(QueryableABC):
@abstractmethod
def __init__(self, _t: type, _values: Iterable = None, _func: Callable = None):
QueryableABC.__init__(self, _t, _values)
self._funcs: list[Callable] = []
if _func is not None:
self._funcs.append(_func)
@abstractmethod
def then_by(self, func: Callable) -> OrderedQueryableABC:
r"""Sorts OrderedList in ascending order by function
Parameter:
func: :class:`Callable`
Returns:
list of :class:`cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""
@abstractmethod
def then_by_descending(self, func: Callable) -> OrderedQueryableABC:
r"""Sorts OrderedList in descending order by function
Parameter:
func: :class:`Callable`
Returns:
list of :class:`cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""

View File

@@ -0,0 +1,569 @@
from __future__ import annotations
from typing import Optional, Callable, Union, Iterable, Any
from cpl.query._helper import is_number
from cpl.query.base import default_lambda
from cpl.query.base.sequence import Sequence
from cpl.query.exceptions import (
InvalidTypeException,
ArgumentNoneException,
ExceptionArgument,
IndexOutOfRangeException,
)
class QueryableABC(Sequence):
def __init__(self, t: type, values: Iterable = None):
Sequence.__init__(self, t, values)
def all(self, _func: Callable = None) -> bool:
r"""Checks if every element of list equals result found by function
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
bool
"""
if _func is None:
_func = default_lambda
return self.count(_func) == self.count()
def any(self, _func: Callable = None) -> bool:
r"""Checks if list contains result found by function
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
bool
"""
if _func is None:
_func = default_lambda
return self.where(_func).count() > 0
def average(self, _func: Callable = None) -> Union[int, float, complex]:
r"""Returns average value of list
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
Union[int, float, complex]
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
return self.sum(_func) / self.count()
def contains(self, _value: object) -> bool:
r"""Checks if list contains value given by function
Parameter
---------
value: :class:`object`
value
Returns
-------
bool
"""
if _value is None:
raise ArgumentNoneException(ExceptionArgument.value)
return self.where(lambda x: x == _value).count() > 0
def count(self, _func: Callable = None) -> int:
r"""Returns length of list or count of found elements
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
int
"""
if _func is None:
return self.__len__()
return self.where(_func).count()
def distinct(self, _func: Callable = None) -> QueryableABC:
r"""Returns list without redundancies
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _func is None:
_func = default_lambda
result = []
known_values = []
for element in self:
value = _func(element)
if value in known_values:
continue
known_values.append(value)
result.append(element)
return type(self)(self._type, result)
def element_at(self, _index: int) -> any:
r"""Returns element at given index
Parameter
---------
_index: :class:`int`
index
Returns
-------
Value at _index: any
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
if _index < 0 or _index >= self.count():
raise IndexOutOfRangeException
result = self._values[_index]
if result is None:
raise IndexOutOfRangeException
return result
def element_at_or_default(self, _index: int) -> Optional[any]:
r"""Returns element at given index or None
Parameter
---------
_index: :class:`int`
index
Returns
-------
Value at _index: Optional[any]
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
try:
return self._values[_index]
except IndexError:
return None
def first(self) -> any:
r"""Returns first element
Returns
-------
First element of list: any
"""
if self.count() == 0:
raise IndexOutOfRangeException()
return self._values[0]
def first_or_default(self) -> any:
r"""Returns first element or None
Returns
-------
First element of list: Optional[any]
"""
if self.count() == 0:
return None
return self._values[0]
def for_each(self, _func: Callable = None):
r"""Runs given function for each element of list
Parameter
---------
func: :class: `Callable`
function to call
"""
if _func is not None:
for element in self:
_func(element)
return self
def group_by(self, _func: Callable = None) -> QueryableABC:
r"""Groups by func
Returns
-------
Grouped list[list[any]]: any
"""
if _func is None:
_func = default_lambda
groups = {}
for v in self:
value = _func(v)
if v not in groups:
groups[value] = []
groups[value].append(v)
v = []
for g in groups.values():
v.append(type(self)(object, g))
x = type(self)(type(self), v)
return x
def last(self) -> any:
r"""Returns last element
Returns
-------
Last element of list: any
"""
if self.count() == 0:
raise IndexOutOfRangeException()
return self._values[self.count() - 1]
def last_or_default(self) -> any:
r"""Returns last element or None
Returns
-------
Last element of list: Optional[any]
"""
if self.count() == 0:
return None
return self._values[self.count() - 1]
def max(self, _func: Callable = None) -> object:
r"""Returns the highest value
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
object
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
if _func is None:
_func = default_lambda
return _func(max(self, key=_func))
def median(self, _func=None) -> Union[int, float]:
r"""Return the median value of data elements
Returns
-------
Union[int, float]
"""
if _func is None:
_func = default_lambda
result = self.order_by(_func).select(_func).to_list()
length = len(result)
i = int(length / 2)
return result[i] if length % 2 == 1 else (float(result[i - 1]) + float(result[i])) / float(2)
def min(self, _func: Callable = None) -> object:
r"""Returns the lowest value
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
object
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
if _func is None:
_func = default_lambda
return _func(min(self, key=_func))
def order_by(self, _func: Callable = None) -> "OrderedQueryableABC":
r"""Sorts elements by function in ascending order
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""
if _func is None:
_func = default_lambda
from cpl.query.base.ordered_queryable import OrderedQueryable
return OrderedQueryable(self.type, sorted(self, key=_func), _func)
def order_by_descending(self, _func: Callable = None) -> "OrderedQueryableABC":
r"""Sorts elements by function in descending order
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
"""
if _func is None:
_func = default_lambda
from cpl.query.base.ordered_queryable import OrderedQueryable
return OrderedQueryable(self.type, sorted(self, key=_func, reverse=True), _func)
def reverse(self) -> QueryableABC:
r"""Reverses list
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
return type(self)(self._type, reversed(self._values))
def select(self, _func: Callable) -> QueryableABC:
r"""Formats each element of list to a given format
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _func is None:
_func = default_lambda
_l = [_func(_o) for _o in self]
_t = type(_l[0]) if len(_l) > 0 else Any
return type(self)(_t, _l)
def select_many(self, _func: Callable) -> QueryableABC:
r"""Flattens resulting lists to one
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
# The line below is pain. I don't understand anything of it...
# written on 09.11.2022 by Sven Heidemann
return type(self)(object, [_a for _o in self for _a in _func(_o)])
def single(self) -> any:
r"""Returns one single element of list
Returns
-------
Found value: any
Raises
------
ArgumentNoneException: when argument is None
Exception: when argument is None or found more than one element
"""
if self.count() > 1:
raise Exception("Found more than one element")
elif self.count() == 0:
raise Exception("Found no element")
return self._values[0]
def single_or_default(self) -> Optional[any]:
r"""Returns one single element of list
Returns
-------
Found value: Optional[any]
"""
if self.count() > 1:
raise Exception("Index out of range")
elif self.count() == 0:
return None
return self._values[0]
def skip(self, _index: int) -> QueryableABC:
r"""Skips all elements from index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
return type(self)(self.type, self._values[_index:])
def skip_last(self, _index: int) -> QueryableABC:
r"""Skips all elements after index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
index = self.count() - _index
return type(self)(self._type, self._values[:index])
def sum(self, _func: Callable = None) -> Union[int, float, complex]:
r"""Sum of all values
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
Union[int, float, complex]
"""
if _func is None and not is_number(self.type):
raise InvalidTypeException()
if _func is None:
_func = default_lambda
result = 0
for x in self:
result += _func(x)
return result
def split(self, _func: Callable) -> QueryableABC:
r"""Splits the list by given function
Parameter
---------
func: :class:`Callable`
seperator
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
groups = []
group = []
for x in self:
v = _func(x)
if x == v:
groups.append(group)
group = []
group.append(x)
groups.append(group)
query_groups = []
for g in groups:
if len(g) == 0:
continue
query_groups.append(type(self)(self._type, g))
return type(self)(self._type, query_groups)
def take(self, _index: int) -> QueryableABC:
r"""Takes all elements from index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _index is None:
raise ArgumentNoneException(ExceptionArgument.index)
return type(self)(self._type, self._values[:_index])
def take_last(self, _index: int) -> QueryableABC:
r"""Takes all elements after index
Parameter
---------
_index: :class:`int`
index
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
index = self.count() - _index
if index >= self.count() or index < 0:
raise IndexOutOfRangeException()
return type(self)(self._type, self._values[index:])
def where(self, _func: Callable = None) -> QueryableABC:
r"""Select element by function
Parameter
---------
func: :class:`Callable`
selected value
Returns
-------
:class: `cpl.query.base.queryable_abc.QueryableABC`
"""
if _func is None:
raise ArgumentNoneException(ExceptionArgument.func)
if _func is None:
_func = default_lambda
return type(self)(self.type, filter(_func, self))

View File

@@ -0,0 +1,96 @@
from abc import abstractmethod, ABC
from typing import Iterable
class Sequence(ABC):
@abstractmethod
def __init__(self, t: type, values: Iterable = None):
assert t is not None
assert isinstance(t, type) or t == any
assert values is None or isinstance(values, Iterable)
if values is None:
values = []
self._values = list(values)
if t is None:
t = object
self._type = t
def __iter__(self):
return iter(self._values)
def __next__(self):
return next(iter(self._values))
def __len__(self):
return self.to_list().__len__()
@classmethod
def __class_getitem__(cls, _t: type) -> type:
return _t
def __repr__(self):
return f"<{type(self).__name__} {self.to_list().__repr__()}>"
@property
def type(self) -> type:
return self._type
def _check_type(self, __object: any):
if self._type == any:
return
if (
self._type is not None
and type(__object) != self._type
and not isinstance(type(__object), self._type)
and not issubclass(type(__object), self._type)
):
raise Exception(f"Unexpected type: {type(__object)}\nExpected type: {self._type}")
def to_list(self) -> list:
r"""Converts :class: `cpl.query.base.sequence_abc.SequenceABC` to :class: `list`
Returns:
:class: `list`
"""
return [x for x in self._values]
def copy(self) -> "Sequence":
r"""Creates a copy of sequence
Returns:
Sequence
"""
return type(self)(self._type, self.to_list())
@classmethod
def empty(cls) -> "Sequence":
r"""Returns an empty sequence
Returns:
Sequence object that contains no elements
"""
return cls(object, [])
def index_of(self, _object: object) -> int:
r"""Returns the index of given element
Returns:
Index of object
Raises:
IndexError if object not in sequence
"""
for i, o in enumerate(self):
if o == _object:
return i
raise IndexError
@classmethod
def range(cls, start: int, length: int) -> "Sequence":
return cls(int, range(start, length))

View File

@@ -1,213 +0,0 @@
from itertools import islice, groupby, chain
from typing import Generic, Callable, Iterable, Iterator, Dict, Tuple, Optional
from cpl.core.typing import T, R
from cpl.query.typing import Predicate, K, Selector
class Enumerable(Generic[T]):
def __init__(self, source: Iterable[T]):
self._source = source
def __iter__(self) -> Iterator[T]:
return iter(self._source)
@property
def length(self) -> int:
return len(list(self._source))
def where(self, f: Predicate) -> "Enumerable[T]":
return Enumerable(x for x in self._source if f(x))
def select(self, f: Selector) -> "Enumerable[R]":
return Enumerable(f(x) for x in self._source)
def select_many(self, f: Callable[[T], Iterable[R]]) -> "Enumerable[R]":
return Enumerable(y for x in self._source for y in f(x))
def take(self, count: int) -> "Enumerable[T]":
return Enumerable(islice(self._source, count))
def skip(self, count: int) -> "Enumerable[T]":
return Enumerable(islice(self._source, count, None))
def take_while(self, f: Predicate) -> "Enumerable[T]":
def generator():
for x in self._source:
if f(x):
yield x
else:
break
return Enumerable(generator())
def skip_while(self, f: Predicate) -> "Enumerable[T]":
def generator():
it = iter(self._source)
for x in it:
if not f(x):
yield x
break
yield from it
return Enumerable(generator())
def distinct(self) -> "Enumerable[T]":
def generator():
seen = set()
for x in self._source:
if x not in seen:
seen.add(x)
yield x
return Enumerable(generator())
def union(self, other: Iterable[T]) -> "Enumerable[T]":
return Enumerable(chain(self.distinct(), Enumerable(other).distinct())).distinct()
def intersect(self, other: Iterable[T]) -> "Enumerable[T]":
other_set = set(other)
return Enumerable(x for x in self._source if x in other_set)
def except_(self, other: Iterable[T]) -> "Enumerable[T]":
other_set = set(other)
return Enumerable(x for x in self._source if x not in other_set)
def concat(self, other: Iterable[T]) -> "Enumerable[T]":
return Enumerable(chain(self._source, other))
# --- Aggregation ---
def count(self) -> int:
return sum(1 for _ in self._source)
def sum(self, f: Optional[Selector] = None) -> R:
if f:
return sum(f(x) for x in self._source)
return sum(self._source) # type: ignore
def min(self, f: Optional[Selector] = None) -> R:
if f:
return min(f(x) for x in self._source)
return min(self._source) # type: ignore
def max(self, f: Optional[Selector] = None) -> R:
if f:
return max(f(x) for x in self._source)
return max(self._source) # type: ignore
def average(self, f: Optional[Callable[[T], float]] = None) -> float:
values = list(self.select(f).to_list()) if f else list(self._source)
return sum(values) / len(values) if values else 0.0
def aggregate(self, func: Callable[[R, T], R], seed: Optional[R] = None) -> R:
it = iter(self._source)
if seed is None:
acc = next(it) # type: ignore
else:
acc = seed
for x in it:
acc = func(acc, x)
return acc
def any(self, f: Optional[Predicate] = None) -> bool:
return any(f(x) if f else x for x in self._source)
def all(self, f: Predicate) -> bool:
return all(f(x) for x in self._source)
def contains(self, value: T) -> bool:
return any(x == value for x in self._source)
def sequence_equal(self, other: Iterable[T]) -> bool:
return list(self._source) == list(other)
def group_by(self, key_f: Callable[[T], K]) -> "Enumerable[Tuple[K, List[T]]]":
def generator():
sorted_data = sorted(self._source, key=key_f)
for key, group in groupby(sorted_data, key=key_f):
yield (key, list(group))
return Enumerable(generator())
def join(
self, inner: Iterable[R], outer_key: Callable[[T], K], inner_key: Callable[[R], K], result: Callable[[T, R], R]
) -> "Enumerable[R]":
def generator():
lookup: Dict[K, List[R]] = {}
for i in inner:
k = inner_key(i)
lookup.setdefault(k, []).append(i)
for o in self._source:
k = outer_key(o)
if k in lookup:
for i in lookup[k]:
yield result(o, i)
return Enumerable(generator())
def first(self, f: Optional[Predicate] = None) -> T:
if f:
for x in self._source:
if f(x):
return x
raise ValueError("No matching element")
return next(iter(self._source))
def first_or_default(self, default: Optional[T] = None) -> Optional[T]:
return next(iter(self._source), default)
def last(self) -> T:
return list(self._source)[-1]
def single(self, f: Optional[Predicate] = None) -> T:
items = [x for x in self._source if f(x)] if f else list(self._source)
if len(items) != 1:
raise ValueError("Sequence does not contain exactly one element")
return items[0]
def to_list(self) -> "List[T]":
from cpl.query.list import List
return List(self)
def to_set(self) -> "Set[T]":
from cpl.query.set import Set
return Set(self)
def to_dict(self, key_f: Callable[[T], K], value_f: Selector) -> Dict[K, R]:
return {key_f(x): value_f(x) for x in self._source}
def cast(self, t: Selector) -> "Enumerable[R]":
return Enumerable(t(x) for x in self._source)
def of_type(self, t: type) -> "Enumerable[T]":
return Enumerable(x for x in self._source if isinstance(x, t))
def reverse(self) -> "Enumerable[T]":
return Enumerable(reversed(list(self._source)))
def zip(self, other: Iterable[R]) -> "Enumerable[Tuple[T, R]]":
return Enumerable(zip(self._source, other))
def order_by(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
from cpl.query.ordered_enumerable import OrderedEnumerable
return OrderedEnumerable(self._source, [(key_selector, False)])
def order_by_descending(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
from cpl.query.ordered_enumerable import OrderedEnumerable
return OrderedEnumerable(self._source, [(key_selector, True)])
@staticmethod
def range(start: int, count: int) -> "Enumerable[int]":
return Enumerable(range(start, start + count))
@staticmethod
def repeat(value: T, count: int) -> "Enumerable[T]":
return Enumerable(value for _ in range(count))
@staticmethod
def empty() -> "Enumerable[T]":
return Enumerable([])

View File

@@ -0,0 +1,2 @@
from .enumerable import Enumerable
from .enumerable_abc import EnumerableABC

View File

@@ -0,0 +1,12 @@
from cpl.query.enumerable.enumerable_abc import EnumerableABC
def _default_lambda(x: object):
return x
class Enumerable(EnumerableABC):
r"""Implementation of :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC`"""
def __init__(self, t: type = None, values: list = None):
EnumerableABC.__init__(self, t, values)

View File

@@ -0,0 +1,21 @@
from abc import abstractmethod
from cpl.query.base.queryable_abc import QueryableABC
class EnumerableABC(QueryableABC):
r"""ABC to define functions on list"""
@abstractmethod
def __init__(self, t: type = None, values: list = None):
QueryableABC.__init__(self, t, values)
def to_iterable(self) -> "IterableABC":
r"""Converts :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC` to :class: `cpl.query.iterable.iterable_abc.IterableABC`
Returns:
:class: `cpl.query.iterable.iterable_abc.IterableABC`
"""
from cpl.query.iterable.iterable import Iterable
return Iterable(self._type, self.to_list())

View File

@@ -0,0 +1,33 @@
from enum import Enum
# models
class ExceptionArgument(Enum):
list = "list"
func = "func"
type = "type"
value = "value"
index = "index"
# exceptions
class ArgumentNoneException(Exception):
r"""Exception when argument is None"""
def __init__(self, arg: ExceptionArgument):
Exception.__init__(self, f"argument {arg} is None")
class IndexOutOfRangeException(Exception):
r"""Exception when index is out of range"""
def __init__(self, err: str = None):
Exception.__init__(self, f"List index out of range" if err is None else err)
class InvalidTypeException(Exception):
r"""Exception when type is invalid"""
class WrongTypeException(Exception):
r"""Exception when type is unexpected"""

View File

@@ -0,0 +1 @@
from .list import List

View File

@@ -0,0 +1,36 @@
from cpl.query.enumerable.enumerable_abc import EnumerableABC
from cpl.query.iterable.iterable import Iterable
class List(Iterable):
r"""Implementation of :class: `cpl.query.extension.iterable.Iterable`"""
def __init__(self, t: type = None, values: Iterable = None):
Iterable.__init__(self, t, values)
def __getitem__(self, *args):
return self._values.__getitem__(*args)
def __setitem__(self, *args):
self._values.__setitem__(*args)
def __delitem__(self, *args):
self._values.__delitem__(*args)
def to_enumerable(self) -> EnumerableABC:
r"""Converts :class: `cpl.query.iterable.iterable_abc.IterableABC` to :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC`
Returns:
:class: `cpl.query.enumerable.enumerable_abc.EnumerableABC`
"""
from cpl.query.enumerable.enumerable import Enumerable
return Enumerable(self._type, self.to_list())
def to_iterable(self) -> Iterable:
r"""Converts :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC` to :class: `cpl.query.iterable.iterable_abc.IterableABC`
Returns:
:class: `cpl.query.iterable.iterable_abc.IterableABC`
"""
return Iterable(self._type, self.to_list())

View File

@@ -1,65 +0,0 @@
from typing import Generic, Iterable, Iterator, Optional
from cpl.core.typing import T
from cpl.query.enumerable import Enumerable
class ImmutableList(Generic[T], Enumerable[T]):
def __init__(self, source: Optional[Iterable[T]] = None):
Enumerable.__init__(self, [])
if source is None:
source = []
elif not isinstance(source, list):
source = list(source)
self.__source = source
@property
def _source(self) -> list[T]:
return self.__source
@_source.setter
def _source(self, value: list[T]) -> None:
self.__source = value
def __iter__(self) -> Iterator[T]:
return iter(self._source)
def __len__(self) -> int:
return len(self._source)
def __getitem__(self, index: int) -> T:
return self._source[index]
def __contains__(self, item: T) -> bool:
return item in self._source
def __repr__(self) -> str:
return f"List({self._source!r})"
@property
def length(self) -> int:
return len(self._source)
def add(self, item: T) -> None:
self._source.append(item)
def extend(self, items: Iterable[T]) -> None:
self._source.extend(items)
def insert(self, index: int, item: T) -> None:
self._source.insert(index, item)
def remove(self, item: T) -> None:
self._source.remove(item)
def pop(self, index: int = -1) -> T:
return self._source.pop(index)
def clear(self) -> None:
self._source.clear()
def to_enumerable(self) -> "Enumerable[T]":
from cpl.query.enumerable import Enumerable
return Enumerable(self._source)

Some files were not shown because too many files have changed in this diff Show More