Compare commits
1 Commits
2025.09.24
...
2025.09.22
| Author | SHA1 | Date | |
|---|---|---|---|
| 69bbbc8cee |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
|
||||
|
||||
# cpl unittest stuff
|
||||
unittests/test_*_playground
|
||||
|
||||
# cpl logs
|
||||
**/logs/*.jsonl
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl import api
|
||||
from cpl.api.application.web_app import WebApp
|
||||
from cpl.application import ApplicationBuilder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema import AuthUser, Role
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.utils.cache import Cache
|
||||
from custom.api.src.scoped_service import ScopedService
|
||||
from service import PingService
|
||||
|
||||
|
||||
def main():
|
||||
builder = ApplicationBuilder[WebApp](WebApp)
|
||||
|
||||
Configuration.add_json_file(f"appsettings.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
||||
|
||||
# builder.services.add_logging()
|
||||
builder.services.add_structured_logging()
|
||||
builder.services.add_transient(PingService)
|
||||
builder.services.add_module(api)
|
||||
|
||||
builder.services.add_scoped(ScopedService)
|
||||
|
||||
builder.services.add_cache(AuthUser)
|
||||
builder.services.add_cache(Role)
|
||||
|
||||
app = builder.build()
|
||||
app.with_logging()
|
||||
app.with_database()
|
||||
|
||||
app.with_authentication()
|
||||
app.with_authorization()
|
||||
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
|
||||
app.with_routes_directory("routes")
|
||||
|
||||
provider = builder.service_provider
|
||||
user_cache = provider.get_service(Cache[AuthUser])
|
||||
role_cache = provider.get_service(Cache[Role])
|
||||
|
||||
if role_cache == user_cache:
|
||||
raise Exception("Cache service is not working")
|
||||
|
||||
s1 = provider.get_service(ScopedService)
|
||||
s2 = provider.get_service(ScopedService)
|
||||
|
||||
if s1.name == s2.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
with provider.create_scope() as scope:
|
||||
s3 = scope.get_service(ScopedService)
|
||||
s4 = scope.get_service(ScopedService)
|
||||
|
||||
if s3.name != s4.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
if s1.name == s3.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
Console.write_line(
|
||||
s1.name,
|
||||
s2.name,
|
||||
s3.name,
|
||||
s4.name,
|
||||
)
|
||||
|
||||
app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,21 +0,0 @@
|
||||
from urllib.request import Request
|
||||
|
||||
from service import PingService
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.api import APILogger
|
||||
from cpl.api.router import Router
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProvider
|
||||
from custom.api.src.scoped_service import ScopedService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
# @Router.authorize(permissions=[Permissions.administrator])
|
||||
# @Router.authorize(policies=["test"])
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
|
||||
logger.info(f"Ping: {ping}")
|
||||
|
||||
Console.write_line(scoped.name)
|
||||
return JSONResponse(ping.ping(r))
|
||||
@@ -1,14 +0,0 @@
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.utils.string import String
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self._name = String.random(8)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -1,10 +0,0 @@
|
||||
from cpl.dependency import ServiceProvider, ServiceProvider
|
||||
from cpl.dependency.inject import inject
|
||||
from di.test_service import TestService
|
||||
|
||||
|
||||
class StaticTest:
|
||||
@staticmethod
|
||||
@inject
|
||||
def test(services: ServiceProvider, t1: TestService):
|
||||
t1.run()
|
||||
@@ -1,10 +0,0 @@
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self.value = "I am a scoped service"
|
||||
Console.write_line(self.value, self)
|
||||
|
||||
def get_value(self):
|
||||
return self.value
|
||||
@@ -1,60 +0,0 @@
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.utils.benchmark import Benchmark
|
||||
from cpl.query.enumerable import Enumerable
|
||||
from cpl.query.immutable_list import ImmutableList
|
||||
from cpl.query.list import List
|
||||
from cpl.query.set import Set
|
||||
|
||||
|
||||
def _default():
|
||||
Console.write_line(Enumerable.empty().to_list())
|
||||
|
||||
Console.write_line(Enumerable.range(0, 100).length)
|
||||
Console.write_line(Enumerable.range(0, 100).to_list())
|
||||
|
||||
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
|
||||
Console.write_line(
|
||||
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
|
||||
)
|
||||
Console.write_line(List)
|
||||
|
||||
s =Enumerable.range(0, 10).to_set()
|
||||
Console.write_line(s)
|
||||
s.add(1)
|
||||
Console.write_line(s)
|
||||
|
||||
data = Enumerable(
|
||||
[
|
||||
{"name": "Alice", "age": 30},
|
||||
{"name": "Dave", "age": 35},
|
||||
{"name": "Charlie", "age": 25},
|
||||
{"name": "Bob", "age": 25},
|
||||
]
|
||||
)
|
||||
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
|
||||
|
||||
|
||||
def t_benchmark(data: list):
|
||||
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all(
|
||||
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
|
||||
)
|
||||
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
|
||||
|
||||
|
||||
def main():
|
||||
N = 10_000_000
|
||||
data = list(range(N))
|
||||
#t_benchmark(data)
|
||||
|
||||
Console.write_line()
|
||||
_default()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +1,5 @@
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
|
||||
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
|
||||
from .logger import APILogger
|
||||
from .settings import ApiSettings
|
||||
|
||||
|
||||
def add_api(collection: _ServiceCollection):
|
||||
try:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .asgi_middleware_abc import ASGIMiddleware
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .web_app import WebApp
|
||||
|
||||
@@ -27,41 +27,40 @@ from cpl.api.settings import ApiSettings
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||
from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = APILogger("API")
|
||||
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
super().__init__(services, [auth, api])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
self._logger = services.get_service(APILogger)
|
||||
|
||||
self._api_settings = Configuration.get(ApiSettings)
|
||||
self._policies = services.get_service(PolicyRegistry)
|
||||
self._routes = services.get_service(RouteRegistry)
|
||||
|
||||
self._middleware: list[Middleware] = []
|
||||
self._middleware: list[Middleware] = [
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
]
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||
Exception: self._handle_exception,
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
self.with_middleware(RequestMiddleware)
|
||||
self.with_middleware(LoggingMiddleware)
|
||||
|
||||
async def _handle_exception(self, request: Request, exc: Exception):
|
||||
@staticmethod
|
||||
async def _handle_exception(request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
self._logger.error(exc)
|
||||
_logger.error(exc)
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
if hasattr(request.state, "request_id"):
|
||||
self._logger.error(f"Request {request.state.request_id}", exc)
|
||||
_logger.error(f"Request {request.state.request_id}", exc)
|
||||
else:
|
||||
self._logger.error("Request unknown", exc)
|
||||
_logger.error("Request unknown", exc)
|
||||
|
||||
return JSONResponse({"error": str(exc)}, status_code=500)
|
||||
|
||||
@@ -69,10 +68,10 @@ class WebApp(ApplicationABC):
|
||||
origins = self._api_settings.allowed_origins
|
||||
|
||||
if origins is None or origins == "":
|
||||
self._logger.warning("No allowed origins specified, allowing all origins")
|
||||
_logger.warning("No allowed origins specified, allowing all origins")
|
||||
return ["*"]
|
||||
|
||||
self._logger.debug(f"Allowed origins: {origins}")
|
||||
_logger.debug(f"Allowed origins: {origins}")
|
||||
return origins.split(",")
|
||||
|
||||
def with_database(self) -> Self:
|
||||
@@ -168,9 +167,9 @@ class WebApp(ApplicationABC):
|
||||
self._check_for_app()
|
||||
|
||||
if isinstance(middleware, Middleware):
|
||||
self._middleware.append(inject(middleware))
|
||||
self._middleware.append(middleware)
|
||||
elif callable(middleware):
|
||||
self._middleware.append(Middleware(inject(middleware)))
|
||||
self._middleware.append(Middleware(middleware))
|
||||
else:
|
||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||
|
||||
@@ -191,11 +190,11 @@ class WebApp(ApplicationABC):
|
||||
if isinstance(policy, dict):
|
||||
for name, resolver in policy.items():
|
||||
if not isinstance(name, str):
|
||||
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||
_logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||
continue
|
||||
|
||||
if not callable(resolver):
|
||||
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||
_logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||
continue
|
||||
|
||||
_policies.append(Policy(name, resolver))
|
||||
@@ -203,7 +202,7 @@ class WebApp(ApplicationABC):
|
||||
|
||||
_policies.append(policy)
|
||||
|
||||
self._policies.extend(_policies)
|
||||
self._policies.extend_policies(_policies)
|
||||
|
||||
self.with_middleware(AuthorizationMiddleware)
|
||||
return self
|
||||
@@ -213,14 +212,14 @@ class WebApp(ApplicationABC):
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||
_logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||
|
||||
async def main(self):
|
||||
self._logger.debug(f"Preparing API")
|
||||
_logger.debug(f"Preparing API")
|
||||
self._validate_policies()
|
||||
|
||||
if self._app is None:
|
||||
routes = [route.to_starlette(inject) for route in self._routes.all()]
|
||||
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
||||
|
||||
app = Starlette(
|
||||
routes=routes,
|
||||
@@ -238,7 +237,7 @@ class WebApp(ApplicationABC):
|
||||
else:
|
||||
app = self._app
|
||||
|
||||
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
|
||||
config = uvicorn.Config(
|
||||
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
|
||||
@@ -246,4 +245,4 @@ class WebApp(ApplicationABC):
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
self._logger.info("Shutdown API")
|
||||
_logger.info("Shutdown API")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
from cpl.core.log.logger import Logger
|
||||
|
||||
|
||||
class APILogger(WrappedLogger):
|
||||
class APILogger(Logger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "api")
|
||||
def __init__(self, source: str):
|
||||
Logger.__init__(self, source, "api")
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .authentication import AuthenticationMiddleware
|
||||
from .authorization import AuthorizationMiddleware
|
||||
from .logging import LoggingMiddleware
|
||||
from .request import RequestMiddleware
|
||||
|
||||
@@ -2,22 +2,24 @@ from keycloak import KeycloakAuthenticationError
|
||||
from starlette.types import Scope, Receive, Send
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._keycloak = keycloak
|
||||
self._user_dao = user_dao
|
||||
|
||||
@@ -26,11 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
self._logger.trace(f"No authentication required for {url}")
|
||||
_logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
|
||||
|
||||
auth_header = request.headers.get("Authorization", None)
|
||||
@@ -39,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
token = auth_header.split("Bearer ")[1]
|
||||
if not await self._verify_login(token):
|
||||
self._logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
|
||||
|
||||
# check user exists in db, if not create
|
||||
@@ -49,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
user = await self._get_or_crate_user(keycloak_id)
|
||||
if user.deleted:
|
||||
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||
|
||||
request.state.user = user
|
||||
@@ -71,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
token_info = self._keycloak.introspect(token)
|
||||
return token_info.get("active", False)
|
||||
except KeycloakAuthenticationError as e:
|
||||
self._logger.debug(f"Keycloak authentication error: {e}")
|
||||
_logger.debug(f"Keycloak authentication error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self._logger.error(f"Unexpected error during token verification: {e}")
|
||||
_logger.error(f"Unexpected error during token verification: {e}")
|
||||
return False
|
||||
|
||||
@@ -9,15 +9,17 @@ from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._policies = policies
|
||||
self._user_dao = user_dao
|
||||
|
||||
@@ -26,7 +28,7 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_authorization_rules_paths():
|
||||
self._logger.trace(f"No authorization required for {url}")
|
||||
_logger.trace(f"No authorization required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
user = get_user()
|
||||
@@ -51,21 +53,17 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
if rule["permissions"]:
|
||||
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
self._logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||
_logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||
continue
|
||||
|
||||
if not await policy.resolve(user):
|
||||
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
return await self._call_next(scope, receive, send)
|
||||
@@ -7,14 +7,14 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class LoggingMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, logger: APILogger):
|
||||
def __init__(self, app):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
await self._call_next(scope, receive, send)
|
||||
@@ -53,8 +53,9 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
}
|
||||
return {key: value for key, value in headers.items() if key in relevant_keys}
|
||||
|
||||
async def _log_request(self, request: Request):
|
||||
self._logger.debug(
|
||||
@classmethod
|
||||
async def _log_request(cls, request: Request):
|
||||
_logger.debug(
|
||||
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
)
|
||||
|
||||
@@ -63,7 +64,7 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
user = get_user()
|
||||
|
||||
request_info = {
|
||||
"headers": self._filter_relevant_headers(dict(request.headers)),
|
||||
"headers": cls._filter_relevant_headers(dict(request.headers)),
|
||||
"args": dict(request.query_params),
|
||||
"form-data": (
|
||||
await request.form()
|
||||
@@ -77,9 +78,10 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
),
|
||||
}
|
||||
|
||||
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
|
||||
async def _log_after_request(self, request: Request, status_code: int, duration: float):
|
||||
self._logger.info(
|
||||
@staticmethod
|
||||
async def _log_after_request(request: Request, status_code: int, duration: float):
|
||||
_logger.info(
|
||||
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
|
||||
@@ -9,20 +9,16 @@ from starlette.types import Scope, Receive, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
|
||||
def __init__(self, app):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
self._logger = logger
|
||||
|
||||
self._ctx_token = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
@@ -30,15 +26,14 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
with self._provider.create_scope():
|
||||
inject(await self._app(scope, receive, send))
|
||||
await self._app(scope, receive, send)
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
|
||||
async def set_request_data(self, request: TRequest):
|
||||
request.state.request_id = uuid4()
|
||||
request.state.start_time = time.time()
|
||||
self._logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
|
||||
self._ctx_token = _request_context.set(request)
|
||||
|
||||
@@ -50,7 +45,7 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
if self._ctx_token is None:
|
||||
return
|
||||
|
||||
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_request_context.reset(self._ctx_token)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .api_route import ApiRoute
|
||||
from .policy import Policy
|
||||
from .validation_match import ValidationMatch
|
||||
|
||||
@@ -7,7 +7,13 @@ from cpl.api.typing import HTTPMethods
|
||||
|
||||
class ApiRoute:
|
||||
|
||||
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
fn: Callable,
|
||||
method: HTTPMethods,
|
||||
**kwargs
|
||||
):
|
||||
self._path = path
|
||||
self._fn = fn
|
||||
self._method = method
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Coroutine, Awaitable
|
||||
|
||||
from cpl.api.typing import PolicyResolver
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
from .policy import PolicyRegistry
|
||||
from .route import RouteRegistry
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from cpl.api.model.policy import Policy
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.core.abc.registry_abc import RegistryABC
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from enum import Enum
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.typing import HTTPMethods
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class Router:
|
||||
@@ -42,13 +41,7 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def authorize(
|
||||
cls,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
):
|
||||
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
|
||||
"""
|
||||
Decorator to mark a route as requiring authorization.
|
||||
Usage:
|
||||
@@ -92,14 +85,15 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
|
||||
if not registry:
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
|
||||
def inner(fn):
|
||||
routes.add(ApiRoute(path, fn, method, **kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
@@ -143,9 +137,8 @@ class Router:
|
||||
"""
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
def inner(fn):
|
||||
path = getattr(fn, "_route_path", None)
|
||||
if path is None:
|
||||
@@ -154,7 +147,7 @@ class Router:
|
||||
route = routes.get(path)
|
||||
if route is None:
|
||||
raise ValueError(f"Cannot override a route that does not exist: {path}")
|
||||
|
||||
|
||||
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
@@ -16,4 +16,4 @@ PartialMiddleware = Union[
|
||||
Middleware,
|
||||
Callable[[ASGIApp], ASGIApp],
|
||||
]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
@@ -2,10 +2,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import Callable, Self
|
||||
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.log import LogSettings
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.log_settings import LogSettings
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
def __not_implemented__(package: str, func: Callable):
|
||||
@@ -16,12 +17,12 @@ class ApplicationABC(ABC):
|
||||
r"""ABC for the Application class
|
||||
|
||||
Parameters:
|
||||
services: :class:`cpl.dependency.service_provider.ServiceProvider`
|
||||
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
|
||||
Contains instances of prepared objects
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
|
||||
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
|
||||
self._services = services
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class ApplicationExtensionABC(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def run(services: ServiceProvider): ...
|
||||
def run(services: ServiceProviderABC): ...
|
||||
|
||||
@@ -7,7 +7,6 @@ from cpl.application.abc.startup_abc import StartupABC
|
||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.dependency.context import get_provider, use_root_provider
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||
@@ -22,7 +21,6 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
self._app = app if app is not None else ApplicationABC
|
||||
|
||||
self._services = ServiceCollection()
|
||||
use_root_provider(self._services.build())
|
||||
|
||||
self._startup: Optional[StartupABC] = None
|
||||
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
|
||||
@@ -36,12 +34,7 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
|
||||
@property
|
||||
def service_provider(self):
|
||||
provider = get_provider()
|
||||
if provider is None:
|
||||
provider = self._services.build()
|
||||
use_root_provider(provider)
|
||||
|
||||
return provider
|
||||
return self._services.build()
|
||||
|
||||
def validate_app_required_modules(self, app: ApplicationABC):
|
||||
for module in app.required_modules:
|
||||
|
||||
@@ -6,7 +6,7 @@ from cpl.auth import permission as _permission
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
from .logger import AuthLogger
|
||||
from .auth_logger import AuthLogger
|
||||
from .keycloak_settings import KeycloakSettings
|
||||
from .permission_seeder import PermissionSeeder
|
||||
|
||||
|
||||
8
src/cpl-auth/cpl/auth/auth_logger.py
Normal file
8
src/cpl-auth/cpl/auth/auth_logger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from cpl.core.log import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class AuthLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "auth")
|
||||
@@ -1,13 +1,15 @@
|
||||
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
from cpl.auth.logger import AuthLogger
|
||||
|
||||
_logger = AuthLogger("keycloak")
|
||||
|
||||
|
||||
class KeycloakAdmin(_KeycloakAdmin):
|
||||
|
||||
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||
# logger.info("Initializing Keycloak admin")
|
||||
def __init__(self, settings: KeycloakSettings):
|
||||
_logger.info("Initializing Keycloak admin")
|
||||
_connection = KeycloakOpenIDConnection(
|
||||
server_url=settings.url,
|
||||
client_id=settings.client_id,
|
||||
|
||||
@@ -2,13 +2,15 @@ from typing import Optional
|
||||
|
||||
from keycloak import KeycloakOpenID
|
||||
|
||||
from cpl.auth.logger import AuthLogger
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
|
||||
_logger = AuthLogger("keycloak")
|
||||
|
||||
|
||||
class KeycloakClient(KeycloakOpenID):
|
||||
|
||||
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||
def __init__(self, settings: KeycloakSettings):
|
||||
KeycloakOpenID.__init__(
|
||||
self,
|
||||
server_url=settings.url,
|
||||
@@ -16,7 +18,7 @@ class KeycloakClient(KeycloakOpenID):
|
||||
realm_name=settings.realm,
|
||||
client_secret_key=settings.client_secret,
|
||||
)
|
||||
logger.info("Initializing Keycloak client")
|
||||
_logger.info("Initializing Keycloak client")
|
||||
|
||||
def get_user_id(self, token: str) -> Optional[str]:
|
||||
info = self.introspect(token)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class KeycloakUser:
|
||||
@@ -32,5 +32,5 @@ class KeycloakUser:
|
||||
def id(self) -> str:
|
||||
from cpl.auth import KeycloakAdmin
|
||||
|
||||
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user_id(self._username)
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class AuthLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "auth")
|
||||
@@ -14,13 +14,14 @@ from cpl.auth.schema import (
|
||||
)
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class PermissionSeeder(DataSeederABC):
|
||||
def __init__(
|
||||
self,
|
||||
logger: DBLogger,
|
||||
permission_dao: PermissionDao,
|
||||
role_dao: RoleDao,
|
||||
role_permission_dao: RolePermissionDao,
|
||||
@@ -28,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
):
|
||||
DataSeederABC.__init__(self)
|
||||
self._logger = logger
|
||||
self._permission_dao = permission_dao
|
||||
self._role_dao = role_dao
|
||||
self._role_permission_dao = role_permission_dao
|
||||
@@ -40,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
possible_permissions = [permission for permission in PermissionsRegistry.get()]
|
||||
|
||||
if len(permissions) == len(possible_permissions):
|
||||
self._logger.info("Permissions already existing")
|
||||
_logger.info("Permissions already existing")
|
||||
await self._update_missing_descriptions()
|
||||
return
|
||||
|
||||
@@ -53,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
|
||||
await self._permission_dao.delete_many(to_delete, hard_delete=True)
|
||||
|
||||
self._logger.warning("Permissions incomplete")
|
||||
_logger.warning("Permissions incomplete")
|
||||
permission_names = [permission.name for permission in permissions]
|
||||
await self._permission_dao.create_many(
|
||||
[
|
||||
|
||||
@@ -10,8 +10,7 @@ from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Id, SerialId
|
||||
from cpl.core.utils.credential_manager import CredentialManager
|
||||
from cpl.database.abc.db_model_abc import DbModelABC
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = Logger(__name__)
|
||||
|
||||
@@ -48,7 +47,7 @@ class ApiKey(DbModelABC):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
|
||||
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
|
||||
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
|
||||
|
||||
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]
|
||||
|
||||
|
||||
@@ -3,12 +3,15 @@ from typing import Optional
|
||||
from cpl.auth.schema._administration.api_key import ApiKey
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ApiKeyDao(DbModelDaoABC[ApiKey]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys"))
|
||||
DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
|
||||
|
||||
self.attribute(ApiKey.identifier, str)
|
||||
self.attribute(ApiKey.key, str, "keystring")
|
||||
|
||||
@@ -6,11 +6,13 @@ from async_property import async_property
|
||||
from keycloak import KeycloakGetError
|
||||
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = AuthLogger(__name__)
|
||||
|
||||
|
||||
class AuthUser(DbModelABC):
|
||||
@@ -36,13 +38,12 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("username")
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user(self._keycloak_id).get("username")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@property
|
||||
@@ -51,39 +52,38 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("email")
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user(self._keycloak_id).get("email")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@async_property
|
||||
async def roles(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
return await auth_user_dao.get_permissions(self.id)
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
return await auth_user_dao.has_permission(self.id, permission)
|
||||
|
||||
async def anonymize(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
|
||||
self._keycloak_id = str(uuid.UUID(int=0))
|
||||
await auth_user_dao.update(self)
|
||||
|
||||
@@ -4,14 +4,17 @@ from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users"))
|
||||
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
||||
|
||||
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
|
||||
|
||||
@@ -36,7 +39,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
||||
result = await self._db.select_map(
|
||||
f"""
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class ApiKeyPermission(DbJoinModelABC):
|
||||
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def api_key(self):
|
||||
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
|
||||
|
||||
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
|
||||
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
|
||||
return await api_key_dao.get_by_id(self._api_key_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||
DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||
|
||||
self.attribute(ApiKeyPermission.api_key_id, int)
|
||||
self.attribute(ApiKeyPermission.permission_id, int)
|
||||
|
||||
@@ -3,12 +3,15 @@ from typing import Optional
|
||||
from cpl.auth.schema._permission.permission import Permission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class PermissionDao(DbModelDaoABC[Permission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions"))
|
||||
DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
|
||||
|
||||
self.attribute(Permission.name, str)
|
||||
self.attribute(Permission.description, Optional[str])
|
||||
|
||||
@@ -6,7 +6,7 @@ from async_property import async_property
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class Role(DbModelABC):
|
||||
@@ -44,22 +44,22 @@ class Role(DbModelABC):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def users(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
|
||||
p = await permission_dao.get_by_name(permission.value)
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from cpl.auth.schema._permission.role import Role
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RoleDao(DbModelDaoABC[Role]):
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, Role, TableManager.get("roles"))
|
||||
DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
|
||||
self.attribute(Role.name, str)
|
||||
self.attribute(Role.description, str)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class RolePermission(DbModelABC):
|
||||
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from cpl.auth.schema._permission.role_permission import RolePermission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RolePermissionDao(DbModelDaoABC[RolePermission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions"))
|
||||
DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
|
||||
|
||||
self.attribute(RolePermission.role_id, int)
|
||||
self.attribute(RolePermission.permission_id, int)
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class RoleUser(DbJoinModelABC):
|
||||
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
|
||||
async def user(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
return await auth_user_dao.get_by_id(self._user_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from cpl.auth.schema._permission.role_user import RoleUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RoleUserDao(DbModelDaoABC[RoleUser]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users"))
|
||||
DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
|
||||
|
||||
self.attribute(RoleUser.role_id, int)
|
||||
self.attribute(RoleUser.user_id, int)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
||||
|
||||
_logger = AuthLogger(__name__)
|
||||
|
||||
def set_user(user: Optional[AuthUser]):
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
|
||||
logger = get_provider().get_service(LoggerABC)
|
||||
logger.trace("Setting user context", user.id)
|
||||
_user_context.set(user)
|
||||
def set_user(user_id: Optional[AuthUser]):
|
||||
_logger.trace("Setting user context", user_id)
|
||||
_user_context.set(user_id)
|
||||
|
||||
|
||||
def get_user() -> Optional[AuthUser]:
|
||||
|
||||
@@ -2,4 +2,3 @@ from .logger import Logger
|
||||
from .logger_abc import LoggerABC
|
||||
from .log_level import LogLevel
|
||||
from .log_settings import LogSettings
|
||||
from .structured_logger import StructuredLogger
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source, Messages
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class StructuredLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source, file_prefix: str = None):
|
||||
Logger.__init__(self, source, file_prefix)
|
||||
|
||||
@property
|
||||
def log_file(self):
|
||||
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
|
||||
|
||||
def _log(self, level: LogLevel, *messages: Messages):
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
formatted_message = self._format_message(level.value, timestamp, *messages)
|
||||
structured_message = self._get_structured_message(level.value, timestamp, formatted_message)
|
||||
|
||||
self._write_log_to_file(level, structured_message)
|
||||
self._write_to_console(level, formatted_message)
|
||||
except Exception as e:
|
||||
print(f"Error while logging: {e} -> {traceback.format_exc()}")
|
||||
|
||||
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
|
||||
structured_message = {
|
||||
"timestamp": timestamp,
|
||||
"level": level.upper(),
|
||||
"source": self._source,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
self._enrich_message_with_request(structured_message)
|
||||
self._enrich_message_with_user(structured_message)
|
||||
|
||||
return json.dumps(structured_message, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
|
||||
scope = dict(request.scope)
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [convert(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): convert(v) for k, v in value.items()}
|
||||
if not isinstance(value, (str, int, float, bool, type(None))):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
|
||||
|
||||
if not include_headers and "headers" in serializable_scope:
|
||||
serializable_scope["headers"] = "<omitted>"
|
||||
|
||||
return serializable_scope
|
||||
|
||||
def _enrich_message_with_request(self, message: dict):
|
||||
if importlib.util.find_spec("cpl.api") is None:
|
||||
return
|
||||
|
||||
from cpl.api.middleware.request import get_request
|
||||
from starlette.requests import Request
|
||||
|
||||
request = get_request()
|
||||
|
||||
if request is None:
|
||||
return
|
||||
|
||||
message["request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"scope": self._scope_to_json(request),
|
||||
}
|
||||
if isinstance(request, Request) and request.scope == "http":
|
||||
request: Request = request # fix typing for IDEs
|
||||
|
||||
message["request"]["data"] = asyncio.create_task(request.body())
|
||||
|
||||
@staticmethod
|
||||
def _enrich_message_with_user(message: dict):
|
||||
if importlib.util.find_spec("cpl-auth") is None:
|
||||
return
|
||||
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
user = get_user()
|
||||
if user is None:
|
||||
return
|
||||
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
|
||||
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
kc_user = keycloak.get_user(user.keycloak_id)
|
||||
message["user"] = {
|
||||
"id": str(user.id),
|
||||
"username": kc_user.get("username"),
|
||||
"email": kc_user.get("email"),
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
import inspect
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.log import LoggerABC, LogLevel
|
||||
from cpl.core.typing import Messages
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class WrappedLogger(LoggerABC):
|
||||
|
||||
def __init__(self, file_prefix: str):
|
||||
LoggerABC.__init__(self)
|
||||
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
|
||||
|
||||
self._source = None
|
||||
self._file_prefix = file_prefix
|
||||
|
||||
self._set_logger()
|
||||
|
||||
@inject
|
||||
def _set_logger(self, services: ServiceProvider):
|
||||
from cpl.core.log import Logger
|
||||
|
||||
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
|
||||
if t_logger is None:
|
||||
raise Exception("No LoggerABC service registered in ServiceProvider")
|
||||
|
||||
self._logger = t_logger(self._source, self._file_prefix)
|
||||
|
||||
def set_level(self, level: LogLevel):
|
||||
self._logger.set_level(level)
|
||||
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._format_message(level, timestamp, *messages)
|
||||
|
||||
@staticmethod
|
||||
def _get_source() -> str | None:
|
||||
stack = inspect.stack()
|
||||
if len(stack) <= 1:
|
||||
return None
|
||||
|
||||
from cpl.dependency import ServiceCollection
|
||||
|
||||
ignore_classes = [
|
||||
ServiceProvider,
|
||||
ServiceProvider.__subclasses__(),
|
||||
ServiceCollection,
|
||||
WrappedLogger,
|
||||
WrappedLogger.__subclasses__(),
|
||||
]
|
||||
|
||||
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
|
||||
|
||||
for i, frame_info in enumerate(stack[1:]):
|
||||
module = inspect.getmodule(frame_info.frame)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
if module.__name__ in ignore_classes or module in ignore_classes:
|
||||
continue
|
||||
|
||||
if module in ignore_modules or module.__name__ in ignore_modules:
|
||||
continue
|
||||
|
||||
if module.__name__ != __name__:
|
||||
return module.__name__
|
||||
|
||||
return None
|
||||
|
||||
def _set_source(self):
|
||||
self._source = self._get_source()
|
||||
self._set_logger()
|
||||
|
||||
def header(self, string: str):
|
||||
self._set_source()
|
||||
self._logger.header(string)
|
||||
|
||||
def trace(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.trace(*messages)
|
||||
|
||||
def debug(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.debug(*messages)
|
||||
|
||||
def info(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.info(*messages)
|
||||
|
||||
def warning(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.warning(*messages)
|
||||
|
||||
def error(self, messages: str, e: Exception = None):
|
||||
self._set_source()
|
||||
self._logger.error(messages, e)
|
||||
|
||||
def fatal(self, messages: str, e: Exception = None):
|
||||
self._set_source()
|
||||
self._logger.fatal(messages, e)
|
||||
@@ -14,4 +14,3 @@ UuidId = str | UUID
|
||||
SerialId = int
|
||||
|
||||
Id = UuidId | SerialId
|
||||
TNumber = int | float | complex
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import time
|
||||
import tracemalloc
|
||||
from typing import List, Callable
|
||||
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class Benchmark:
|
||||
|
||||
@staticmethod
|
||||
def all(label: str, func: Callable, iterations: int = 5):
|
||||
times: List[float] = []
|
||||
mems: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
func()
|
||||
end = time.perf_counter()
|
||||
times.append(end - start)
|
||||
|
||||
for _ in range(iterations):
|
||||
tracemalloc.start()
|
||||
func()
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
mems.append(peak)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
|
||||
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
|
||||
|
||||
@staticmethod
|
||||
def time(label: str, func: Callable, iterations: int = 5):
|
||||
times: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
func()
|
||||
end = time.perf_counter()
|
||||
times.append(end - start)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
|
||||
|
||||
@staticmethod
|
||||
def memory(label: str, func: Callable, iterations: int = 5):
|
||||
mems: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
tracemalloc.start()
|
||||
func()
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
mems.append(peak)
|
||||
|
||||
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
|
||||
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")
|
||||
@@ -1,100 +0,0 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
|
||||
|
||||
class Cache(Generic[T]):
|
||||
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
|
||||
self._store = {}
|
||||
self._default_ttl = default_ttl
|
||||
self._lock = threading.Lock()
|
||||
self._cleanup_interval = cleanup_interval
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
self._type = t
|
||||
|
||||
# Start background cleanup thread
|
||||
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def set(self, key: str, value: T, ttl: int = None) -> None:
|
||||
"""Store a value in the cache with optional TTL override."""
|
||||
expire_at = None
|
||||
ttl = ttl if ttl is not None else self._default_ttl
|
||||
if ttl is not None:
|
||||
expire_at = time.time() + ttl
|
||||
|
||||
with self._lock:
|
||||
self._store[key] = (value, expire_at)
|
||||
|
||||
def get(self, key: str) -> T | None:
|
||||
"""Retrieve a value from the cache if not expired."""
|
||||
with self._lock:
|
||||
item = self._store.get(key)
|
||||
if not item:
|
||||
return None
|
||||
value, expire_at = item
|
||||
if expire_at and expire_at < time.time():
|
||||
# Expired -> remove and return None
|
||||
del self._store[key]
|
||||
return None
|
||||
return value
|
||||
|
||||
def get_all(self) -> list[T]:
|
||||
"""Retrieve all non-expired values from the cache."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
valid_items = []
|
||||
expired_keys = []
|
||||
for k, (v, exp) in self._store.items():
|
||||
if exp and exp < now:
|
||||
expired_keys.append(k)
|
||||
else:
|
||||
valid_items.append(v)
|
||||
for k in expired_keys:
|
||||
del self._store[k]
|
||||
return valid_items
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
with self._lock:
|
||||
item = self._store.get(key)
|
||||
if not item:
|
||||
return False
|
||||
_, expire_at = item
|
||||
if expire_at and expire_at < time.time():
|
||||
# Expired -> remove and return False
|
||||
del self._store[key]
|
||||
return False
|
||||
return True
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Remove an item from the cache."""
|
||||
with self._lock:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the entire cache."""
|
||||
with self._lock:
|
||||
self._store.clear()
|
||||
|
||||
def _auto_cleanup(self):
|
||||
"""Background thread to clean expired items."""
|
||||
while not self._stop_event.is_set():
|
||||
self.cleanup()
|
||||
self._stop_event.wait(self._cleanup_interval)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove expired items immediately."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
|
||||
for k in expired_keys:
|
||||
del self._store[k]
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background cleanup thread."""
|
||||
self._stop_event.set()
|
||||
self._thread.join()
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Number:
|
||||
|
||||
@staticmethod
|
||||
def is_number(value: Any) -> bool:
|
||||
"""Check if the value is a number (int or float)."""
|
||||
return isinstance(value, (int, float, complex))
|
||||
|
||||
@staticmethod
|
||||
def to_number(value: Any) -> int | float | complex:
|
||||
"""
|
||||
Convert a given value into int, float, or complex.
|
||||
Raises ValueError if conversion is not possible.
|
||||
"""
|
||||
|
||||
if isinstance(value, (int, float, complex)):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
for caster in (int, float, complex):
|
||||
try:
|
||||
return caster(value)
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError(f"Cannot convert string '{value}' to number.")
|
||||
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return complex(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise ValueError(f"Cannot convert type {type(value)} to number.")
|
||||
@@ -114,15 +114,12 @@ class String:
|
||||
|
||||
characters = []
|
||||
if letters:
|
||||
characters.extend(string.ascii_letters)
|
||||
characters.append(string.ascii_letters)
|
||||
|
||||
if digits:
|
||||
characters.extend(string.digits)
|
||||
characters.append(string.digits)
|
||||
|
||||
if special_characters:
|
||||
characters.extend(string.punctuation)
|
||||
characters.append(string.punctuation)
|
||||
|
||||
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else ""
|
||||
if len(x) != length:
|
||||
raise Exception("No characters selected to generate random string")
|
||||
return x
|
||||
return "".join(random.choice(characters) for _ in range(length)) if characters else ""
|
||||
|
||||
@@ -9,19 +9,25 @@ from cpl.core.utils.get_value import get_value
|
||||
from cpl.core.utils.string import String
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.const import DATETIME_FORMAT
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||
self._db = get_provider().get_service(DBContextABC)
|
||||
self._logger = get_provider().get_service(DBLogger)
|
||||
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
self._db = ServiceProviderABC.get_global_service(DBContextABC)
|
||||
|
||||
self._logger = DBLogger(source)
|
||||
self._model_type = model_type
|
||||
self._table_name = table_name
|
||||
|
||||
self._logger = DBLogger(source)
|
||||
self._model_type = model_type
|
||||
self._table_name = table_name
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from cpl.database.abc.db_model_abc import DbModelABC
|
||||
class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||
DataAccessObjectABC.__init__(self, model_type, table_name)
|
||||
def __init__(self, source: str, model_type: Type[T_DBM], table_name: str):
|
||||
DataAccessObjectABC.__init__(self, source, model_type, table_name)
|
||||
|
||||
self.attribute(DbModelABC.id, int, ignore=True)
|
||||
self.attribute(DbModelABC.deleted, bool)
|
||||
|
||||
8
src/cpl-database/cpl/database/db_logger.py
Normal file
8
src/cpl-database/cpl/database/db_logger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from cpl.core.log import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class DBLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "db")
|
||||
@@ -1,7 +0,0 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class DBLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "db")
|
||||
@@ -4,17 +4,18 @@ from typing import Any, List, Dict, Tuple, Union
|
||||
from mysql.connector import Error as MySQLError, PoolError
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.model.database_settings import DatabaseSettings
|
||||
from cpl.database.mysql.mysql_pool import MySQLPool
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class DBContext(DBContextABC):
|
||||
def __init__(self, logger: DBLogger):
|
||||
def __init__(self):
|
||||
DBContextABC.__init__(self)
|
||||
self._logger = logger
|
||||
|
||||
self._pool: MySQLPool = None
|
||||
self._fails = 0
|
||||
|
||||
@@ -22,62 +23,62 @@ class DBContext(DBContextABC):
|
||||
|
||||
def connect(self, database_settings: DatabaseSettings):
|
||||
try:
|
||||
self._logger.debug("Connecting to database")
|
||||
_logger.debug("Connecting to database")
|
||||
self._pool = MySQLPool(
|
||||
database_settings,
|
||||
)
|
||||
self._logger.info("Connected to database")
|
||||
_logger.info("Connected to database")
|
||||
except Exception as e:
|
||||
self._logger.fatal("Connecting to database failed", e)
|
||||
_logger.fatal("Connecting to database failed", e)
|
||||
|
||||
async def execute(self, statement: str, args=None, multi=True) -> List[List]:
|
||||
self._logger.trace(f"execute {statement} with args: {args}")
|
||||
_logger.trace(f"execute {statement} with args: {args}")
|
||||
return await self._pool.execute(statement, args, multi)
|
||||
|
||||
async def select_map(self, statement: str, args=None) -> List[Dict]:
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select_map(statement, args)
|
||||
except (MySQLError, PoolError) as e:
|
||||
if self._fails >= 3:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
self._logger.debug("Retry select")
|
||||
_logger.debug("Retry select")
|
||||
return await self.select_map(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
async def select(self, statement: str, args=None) -> Union[List[str], List[Tuple], List[Any]]:
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select(statement, args)
|
||||
except (MySQLError, PoolError) as e:
|
||||
if self._fails >= 3:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
self._logger.debug("Retry select")
|
||||
_logger.debug("Retry select")
|
||||
return await self.select(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
@@ -4,9 +4,10 @@ import sqlparse
|
||||
from mysql.connector.aio import MySQLConnectionPool
|
||||
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class MySQLPool:
|
||||
@@ -35,8 +36,7 @@ class MySQLPool:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Error connecting to the database: {e}")
|
||||
_logger.fatal(f"Error connecting to the database: {e}")
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
|
||||
@@ -7,16 +7,16 @@ from psycopg_pool import PoolTimeout
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.database.database_settings import DatabaseSettings
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.postgres.postgres_pool import PostgresPool
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class DBContext(DBContextABC):
|
||||
def __init__(self, logger: DBLogger):
|
||||
def __init__(self):
|
||||
DBContextABC.__init__(self)
|
||||
|
||||
self._logger = logger
|
||||
self._pool: PostgresPool = None
|
||||
self._fails = 0
|
||||
|
||||
@@ -24,63 +24,63 @@ class DBContext(DBContextABC):
|
||||
|
||||
def connect(self, database_settings: DatabaseSettings):
|
||||
try:
|
||||
self._logger.debug("Connecting to database")
|
||||
_logger.debug("Connecting to database")
|
||||
self._pool = PostgresPool(
|
||||
database_settings,
|
||||
Environment.get("DB_POOL_SIZE", int, 1),
|
||||
)
|
||||
self._logger.info("Connected to database")
|
||||
_logger.info("Connected to database")
|
||||
except Exception as e:
|
||||
self._logger.fatal("Connecting to database failed", e)
|
||||
_logger.fatal("Connecting to database failed", e)
|
||||
|
||||
async def execute(self, statement: str, args=None, multi=True) -> list[list]:
|
||||
self._logger.trace(f"execute {statement} with args: {args}")
|
||||
_logger.trace(f"execute {statement} with args: {args}")
|
||||
return await self._pool.execute(statement, args, multi)
|
||||
|
||||
async def select_map(self, statement: str, args=None) -> list[dict]:
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select_map(statement, args)
|
||||
except (OperationalError, PoolTimeout) as e:
|
||||
if self._fails >= 3:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
self._logger.debug("Retry select")
|
||||
_logger.debug("Retry select")
|
||||
return await self.select_map(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
async def select(self, statement: str, args=None) -> list[str] | list[tuple] | list[Any]:
|
||||
self._logger.trace(f"select {statement} with args: {args}")
|
||||
_logger.trace(f"select {statement} with args: {args}")
|
||||
try:
|
||||
return await self._pool.select(statement, args)
|
||||
except (OperationalError, PoolTimeout) as e:
|
||||
if self._fails >= 3:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
uid = uuid.uuid4()
|
||||
raise Exception(
|
||||
f"Query failed three times with {type(e).__name__}. Contact an admin with the UID: {uid}"
|
||||
)
|
||||
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
self._fails += 1
|
||||
try:
|
||||
self._logger.debug("Retry select")
|
||||
_logger.debug("Retry select")
|
||||
return await self.select(statement, args)
|
||||
except Exception as e:
|
||||
pass
|
||||
return []
|
||||
except Exception as e:
|
||||
self._logger.error(f"Database error caused by `{statement}`", e)
|
||||
_logger.error(f"Database error caused by `{statement}`", e)
|
||||
raise e
|
||||
|
||||
@@ -5,9 +5,10 @@ from psycopg import sql
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class PostgresPool:
|
||||
@@ -37,8 +38,7 @@ class PostgresPool:
|
||||
await pool.check_connection(con)
|
||||
except PoolTimeout as e:
|
||||
await pool.close()
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Failed to connect to the database", e)
|
||||
_logger.fatal(f"Failed to connect to the database", e)
|
||||
self._pool = pool
|
||||
|
||||
return self._pool
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ExecutedMigrationDao(DataAccessObjectABC[ExecutedMigration]):
|
||||
|
||||
def __init__(self):
|
||||
DataAccessObjectABC.__init__(self, ExecutedMigration, TableManager.get("executed_migrations"))
|
||||
DataAccessObjectABC.__init__(self, __name__, ExecutedMigration, TableManager.get("executed_migrations"))
|
||||
|
||||
self.attribute(ExecutedMigration.migration_id, str, primary_key=True, db_name="migrationId")
|
||||
|
||||
@@ -2,17 +2,18 @@ import glob
|
||||
import os
|
||||
|
||||
from cpl.database.abc import DBContextABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.model import Migration
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.schema.executed_migration import ExecutedMigration
|
||||
from cpl.database.schema.executed_migration_dao import ExecutedMigrationDao
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class MigrationService:
|
||||
|
||||
def __init__(self, logger: DBLogger, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
||||
self._logger = logger
|
||||
def __init__(self, db: DBContextABC, executedMigrationDao: ExecutedMigrationDao):
|
||||
self._db = db
|
||||
self._executedMigrationDao = executedMigrationDao
|
||||
|
||||
@@ -95,13 +96,13 @@ class MigrationService:
|
||||
if migration_from_db is not None:
|
||||
continue
|
||||
|
||||
self._logger.debug(f"Running upgrade migration: {migration.name}")
|
||||
_logger.debug(f"Running upgrade migration: {migration.name}")
|
||||
|
||||
await self._db.execute(migration.script, multi=True)
|
||||
|
||||
await self._executedMigrationDao.create(ExecutedMigration(migration.name), skip_editor=True)
|
||||
except Exception as e:
|
||||
self._logger.fatal(
|
||||
_logger.fatal(
|
||||
f"Migration failed: {migration.name}\n{active_statement}",
|
||||
e,
|
||||
)
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class SeederService:
|
||||
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
def __init__(self, provider: ServiceProviderABC):
|
||||
self._provider = provider
|
||||
self._logger = provider.get_service(DBLogger)
|
||||
|
||||
async def seed(self):
|
||||
seeders = self._provider.get_services(DataSeederABC)
|
||||
self._logger.debug(f"Found {len(seeders)} seeders")
|
||||
_logger.debug(f"Found {len(seeders)} seeders")
|
||||
for seeder in seeders:
|
||||
await seeder.seed()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .context import get_provider, use_provider
|
||||
from .inject import inject
|
||||
from .scope import Scope
|
||||
from .scope_abc import ScopeABC
|
||||
from .service_collection import ServiceCollection
|
||||
from .service_descriptor import ServiceDescriptor
|
||||
from .service_lifetime import ServiceLifetimeEnum
|
||||
from .service_provider import ServiceProvider
|
||||
from .service_lifetime_enum import ServiceLifetimeEnum
|
||||
from .service_provider import ServiceProvider
|
||||
from .service_provider_abc import ServiceProviderABC
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
|
||||
_current_provider = contextvars.ContextVar("current_provider", default=None)
|
||||
|
||||
|
||||
def use_root_provider(provider):
|
||||
_current_provider.set(provider)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_provider(provider):
|
||||
token = _current_provider.set(provider)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_current_provider.reset(token)
|
||||
|
||||
|
||||
def get_provider():
|
||||
return _current_provider.get()
|
||||
@@ -1,41 +0,0 @@
|
||||
import functools
|
||||
from asyncio import iscoroutinefunction
|
||||
from inspect import signature
|
||||
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
def inject(f=None):
|
||||
if f is None:
|
||||
return functools.partial(inject)
|
||||
|
||||
if iscoroutinefunction(f):
|
||||
@functools.wraps(f)
|
||||
async def async_inner(*args, **kwargs):
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
provider: ServiceProvider | None = get_provider()
|
||||
if provider is None:
|
||||
raise ValueError(
|
||||
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
|
||||
)
|
||||
|
||||
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
|
||||
return await f(*args, *injection, **kwargs)
|
||||
|
||||
return async_inner
|
||||
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
provider: ServiceProvider | None = get_provider()
|
||||
if provider is None:
|
||||
raise ValueError(
|
||||
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
|
||||
)
|
||||
|
||||
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
|
||||
return f(*args, *injection, **kwargs)
|
||||
|
||||
return inner
|
||||
22
src/cpl-dependency/cpl/dependency/scope.py
Normal file
22
src/cpl-dependency/cpl/dependency/scope.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class Scope(ScopeABC):
|
||||
def __init__(self, service_provider: ServiceProviderABC):
|
||||
self._service_provider = service_provider
|
||||
self._service_provider.set_scope(self)
|
||||
ScopeABC.__init__(self)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.dispose()
|
||||
|
||||
@property
|
||||
def service_provider(self) -> ServiceProviderABC:
|
||||
return self._service_provider
|
||||
|
||||
def dispose(self):
|
||||
self._service_provider = None
|
||||
20
src/cpl-dependency/cpl/dependency/scope_abc.py
Normal file
20
src/cpl-dependency/cpl/dependency/scope_abc.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ScopeABC(ABC):
|
||||
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
|
||||
|
||||
def __init__(self): ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def service_provider(self):
|
||||
r"""Returns to service provider of scope
|
||||
|
||||
Returns:
|
||||
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dispose(self):
|
||||
r"""Sets service_provider to None"""
|
||||
18
src/cpl-dependency/cpl/dependency/scope_builder.py
Normal file
18
src/cpl-dependency/cpl/dependency/scope_builder.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from cpl.dependency.scope import Scope
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class ScopeBuilder:
|
||||
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
|
||||
|
||||
def __init__(self, service_provider: ServiceProviderABC) -> None:
|
||||
self._service_provider = service_provider
|
||||
|
||||
def build(self) -> ScopeABC:
|
||||
r"""Returns scope
|
||||
|
||||
Returns:
|
||||
Object of type :class:`cpl.dependency.scope.Scope`
|
||||
"""
|
||||
return Scope(self._service_provider)
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import Union, Type, Callable, Self
|
||||
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.core.typing import T, Service
|
||||
from cpl.core.utils.cache import Cache
|
||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class ServiceCollection:
|
||||
@@ -61,8 +62,9 @@ class ServiceCollection:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
||||
return self
|
||||
|
||||
def build(self) -> ServiceProvider:
|
||||
def build(self) -> ServiceProviderABC:
|
||||
sp = ServiceProvider(self._service_descriptors)
|
||||
ServiceProviderABC.set_global_provider(sp)
|
||||
return sp
|
||||
|
||||
def add_module(self, module: str | object) -> Self:
|
||||
@@ -78,24 +80,5 @@ class ServiceCollection:
|
||||
return self
|
||||
|
||||
def add_logging(self) -> Self:
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
self.add_transient(LoggerABC, Logger)
|
||||
for wrapper in WrappedLogger.__subclasses__():
|
||||
self.add_transient(wrapper)
|
||||
return self
|
||||
|
||||
def add_structured_logging(self) -> Self:
|
||||
from cpl.core.log.structured_logger import StructuredLogger
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
self.add_transient(LoggerABC, StructuredLogger)
|
||||
|
||||
for wrapper in WrappedLogger.__subclasses__():
|
||||
self.add_transient(wrapper)
|
||||
return self
|
||||
|
||||
def add_cache(self, t: Type[T]):
|
||||
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
|
||||
return self
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Union, Optional
|
||||
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
|
||||
|
||||
|
||||
class ServiceDescriptor:
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class ServiceLifetimeEnum(Enum):
|
||||
singleton = auto()
|
||||
scoped = auto()
|
||||
transient = auto()
|
||||
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ServiceLifetimeEnum(Enum):
|
||||
singleton = 0
|
||||
scoped = 1
|
||||
transient = 2
|
||||
@@ -1,41 +1,44 @@
|
||||
import copy
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
from inspect import signature, Parameter, Signature
|
||||
from typing import Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.typing import T, R, Source
|
||||
from cpl.dependency import use_provider
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
from cpl.dependency.scope_builder import ScopeBuilder
|
||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class ServiceProvider:
|
||||
def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False):
|
||||
class ServiceProvider(ServiceProviderABC):
|
||||
r"""Provider for the services
|
||||
|
||||
Parameter
|
||||
---------
|
||||
service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`]
|
||||
Descriptor of the service
|
||||
config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC`
|
||||
CPL Configuration
|
||||
db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`]
|
||||
Database representation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_descriptors: list[ServiceDescriptor],
|
||||
):
|
||||
ServiceProviderABC.__init__(self)
|
||||
|
||||
self._service_descriptors: list[ServiceDescriptor] = service_descriptors
|
||||
self._is_scope = is_scope
|
||||
self._scope: Optional[ScopeABC] = None
|
||||
|
||||
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
|
||||
origin_type = typing.get_origin(service_type) or service_type
|
||||
type_args = list(typing.get_args(service_type))
|
||||
|
||||
for descriptor in self._service_descriptors:
|
||||
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type
|
||||
descriptor_type_args = list(typing.get_args(descriptor.base_type))
|
||||
|
||||
if descriptor_base_type == origin_type and len(descriptor_type_args) == 0 and len(type_args) == 0:
|
||||
return descriptor
|
||||
|
||||
if descriptor_base_type != origin_type or len(descriptor_type_args) != len(type_args):
|
||||
continue
|
||||
|
||||
if descriptor_base_type == origin_type and type_args != descriptor_type_args:
|
||||
continue
|
||||
|
||||
if descriptor.service_type == origin_type or issubclass(descriptor.base_type, origin_type):
|
||||
if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
|
||||
return descriptor
|
||||
|
||||
return None
|
||||
@@ -43,13 +46,13 @@ class ServiceProvider:
|
||||
def _get_service(self, parameter: Parameter, origin_service_type: type = None) -> Optional[object]:
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.service_type == parameter.annotation or issubclass(
|
||||
descriptor.service_type, parameter.annotation
|
||||
descriptor.service_type, parameter.annotation
|
||||
):
|
||||
if descriptor.implementation is not None:
|
||||
return descriptor.implementation
|
||||
|
||||
implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
|
||||
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
|
||||
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
|
||||
descriptor.implementation = implementation
|
||||
|
||||
return implementation
|
||||
@@ -67,7 +70,7 @@ class ServiceProvider:
|
||||
implementation = self._build_service(
|
||||
descriptor.service_type, origin_service_type=service_type, **kwargs
|
||||
)
|
||||
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
|
||||
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
|
||||
descriptor.implementation = implementation
|
||||
|
||||
implementations.append(implementation)
|
||||
@@ -85,7 +88,7 @@ class ServiceProvider:
|
||||
elif parameter.annotation == Source:
|
||||
params.append(origin_service_type.__name__)
|
||||
|
||||
elif issubclass(parameter.annotation, ServiceProvider):
|
||||
elif issubclass(parameter.annotation, ServiceProviderABC):
|
||||
params.append(self)
|
||||
|
||||
elif issubclass(parameter.annotation, Environment):
|
||||
@@ -113,27 +116,32 @@ class ServiceProvider:
|
||||
service_type = type(descriptor.implementation)
|
||||
else:
|
||||
service_type = descriptor.service_type
|
||||
|
||||
break
|
||||
|
||||
sig = signature(service_type.__init__)
|
||||
params = self._build_by_signature(sig, origin_service_type)
|
||||
|
||||
return service_type(*params, *args, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def create_scope(self):
|
||||
scoped_descriptors = []
|
||||
for d in self._service_descriptors:
|
||||
if d.lifetime == ServiceLifetimeEnum.singleton:
|
||||
scoped_descriptors.append(d)
|
||||
else:
|
||||
scoped_descriptors.append(copy.deepcopy(d))
|
||||
def set_scope(self, scope: ScopeABC):
|
||||
self._scope = scope
|
||||
|
||||
scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True)
|
||||
with use_provider(scoped_provider):
|
||||
yield scoped_provider
|
||||
def create_scope(self) -> ScopeABC:
|
||||
descriptors = []
|
||||
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
|
||||
descriptors.append(descriptor)
|
||||
else:
|
||||
descriptors.append(copy.deepcopy(descriptor))
|
||||
|
||||
sb = ScopeBuilder(ServiceProvider(descriptors))
|
||||
return sb.build()
|
||||
|
||||
def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
|
||||
result = self._find_service(service_type)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
@@ -141,30 +149,21 @@ class ServiceProvider:
|
||||
return result.implementation
|
||||
|
||||
implementation = self._build_service(service_type, *args, **kwargs)
|
||||
|
||||
if result.lifetime == ServiceLifetimeEnum.singleton:
|
||||
result.implementation = implementation
|
||||
elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope:
|
||||
if (
|
||||
result.lifetime == ServiceLifetimeEnum.singleton
|
||||
or result.lifetime == ServiceLifetimeEnum.scoped
|
||||
and self._scope is not None
|
||||
):
|
||||
result.implementation = implementation
|
||||
|
||||
return implementation
|
||||
|
||||
def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]:
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
|
||||
return descriptor.service_type
|
||||
return None
|
||||
|
||||
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
|
||||
implementations = []
|
||||
|
||||
if typing.get_origin(service_type) == list:
|
||||
raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
|
||||
implementations.extend(self._get_services(service_type))
|
||||
return implementations
|
||||
|
||||
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
|
||||
types = []
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):
|
||||
types.append(descriptor.service_type)
|
||||
return types
|
||||
implementations.extend(self._get_services(service_type))
|
||||
|
||||
return implementations
|
||||
|
||||
137
src/cpl-dependency/cpl/dependency/service_provider_abc.py
Normal file
137
src/cpl-dependency/cpl/dependency/service_provider_abc.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import functools
|
||||
from abc import abstractmethod, ABC
|
||||
from inspect import Signature, signature, iscoroutinefunction
|
||||
from typing import Optional, Type
|
||||
|
||||
from cpl.core.typing import T, R
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
|
||||
|
||||
class ServiceProviderABC(ABC):
|
||||
r"""ABC for the class :class:`cpl.dependency.service_provider.ServiceProvider`"""
|
||||
|
||||
_provider: Optional["ServiceProviderABC"] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self): ...
|
||||
|
||||
@classmethod
|
||||
def set_global_provider(cls, provider: "ServiceProviderABC"):
|
||||
cls._provider = provider
|
||||
|
||||
@classmethod
|
||||
def get_global_provider(cls) -> Optional["ServiceProviderABC"]:
|
||||
return cls._provider
|
||||
|
||||
@classmethod
|
||||
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||
if cls._provider is None:
|
||||
return None
|
||||
return cls._provider.get_service(instance_type, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||
if cls._provider is None:
|
||||
return []
|
||||
return cls._provider.get_services(instance_type, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _build_service(self, service_type: type, *args, **kwargs) -> object:
|
||||
r"""Creates instance of given type
|
||||
|
||||
Parameter
|
||||
---------
|
||||
instance_type: :class:`type`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of the given type
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_scope(self, scope: ScopeABC):
|
||||
r"""Sets the scope of service provider
|
||||
|
||||
Parameter
|
||||
---------
|
||||
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
|
||||
Service scope
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_scope(self) -> ScopeABC:
|
||||
r"""Creates a service scope
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_service(self, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||
r"""Returns instance of given type
|
||||
|
||||
Parameter
|
||||
---------
|
||||
instance_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type Optional[:class:`cpl.core.type.T`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||
r"""Returns instance of given type
|
||||
|
||||
Parameter
|
||||
---------
|
||||
service_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type list[Optional[:class:`cpl.core.type.T`]
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def inject(cls, f=None):
|
||||
r"""Decorator to allow injection into static and class methods
|
||||
|
||||
Parameter
|
||||
---------
|
||||
f: Callable
|
||||
|
||||
Returns
|
||||
-------
|
||||
function
|
||||
"""
|
||||
if f is None:
|
||||
return functools.partial(cls.inject)
|
||||
|
||||
if iscoroutinefunction(f):
|
||||
|
||||
@functools.wraps(f)
|
||||
async def async_inner(*args, **kwargs):
|
||||
if cls._provider is None:
|
||||
raise Exception(f"{cls.__name__} not build!")
|
||||
|
||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||
return await f(*args, *injection, **kwargs)
|
||||
|
||||
return async_inner
|
||||
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
if cls._provider is None:
|
||||
raise Exception(f"{cls.__name__} not build!")
|
||||
|
||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||
return f(*args, *injection, **kwargs)
|
||||
|
||||
return inner
|
||||
@@ -3,7 +3,7 @@ from .abc.email_client_abc import EMailClientABC
|
||||
from .email_client import EMailClient
|
||||
from .email_client_settings import EMailClientSettings
|
||||
from .email_model import EMail
|
||||
from .logger import MailLogger
|
||||
from .mail_logger import MailLogger
|
||||
|
||||
|
||||
def add_mail(collection: _ServiceCollection):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from cpl.mail.abc.email_client_abc import EMailClientABC
|
||||
from cpl.mail.email_client_settings import EMailClientSettings
|
||||
from cpl.mail.email_model import EMail
|
||||
from cpl.mail.logger import MailLogger
|
||||
from cpl.mail.mail_logger import MailLogger
|
||||
|
||||
|
||||
class EMailClient(EMailClientABC):
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class MailLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "mail")
|
||||
8
src/cpl-mail/cpl/mail/mail_logger.py
Normal file
8
src/cpl-mail/cpl/mail/mail_logger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class MailLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "mail")
|
||||
@@ -1,7 +1 @@
|
||||
from .array import Array
|
||||
from .enumerable import Enumerable
|
||||
from .immutable_list import ImmutableList
|
||||
from .immutable_set import ImmutableSet
|
||||
from .list import List
|
||||
from .ordered_enumerable import OrderedEnumerable
|
||||
from .set import Set
|
||||
|
||||
|
||||
2
src/cpl-query/cpl/query/_helper.py
Normal file
2
src/cpl-query/cpl/query/_helper.py
Normal file
@@ -0,0 +1,2 @@
|
||||
def is_number(t: type) -> bool:
|
||||
return issubclass(t, int) or issubclass(t, float) or issubclass(t, complex)
|
||||
@@ -1,44 +0,0 @@
|
||||
from typing import Generic, Iterable, Optional
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.query.list import List
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
|
||||
class Array(Generic[T], List[T]):
|
||||
def __init__(self, length: int, source: Optional[Iterable[T]] = None):
|
||||
List.__init__(self, source)
|
||||
self._length = length
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self._source)
|
||||
|
||||
def add(self, item: T) -> None:
|
||||
if self._length == self.length:
|
||||
raise IndexError("Array is full")
|
||||
self._source.append(item)
|
||||
|
||||
def extend(self, items: Iterable[T]) -> None:
|
||||
if self._length == self.length:
|
||||
raise IndexError("Array is full")
|
||||
self._source.extend(items)
|
||||
|
||||
def insert(self, index: int, item: T) -> None:
|
||||
if index < 0 or index > self.length:
|
||||
raise IndexError("Index out of range")
|
||||
self._source.insert(index, item)
|
||||
|
||||
def remove(self, item: T) -> None:
|
||||
self._source.remove(item)
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
return self._source.pop(index)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._source.clear()
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._source)
|
||||
5
src/cpl-query/cpl/query/base/__init__.py
Normal file
5
src/cpl-query/cpl/query/base/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .default_lambda import default_lambda
|
||||
from .ordered_queryable import OrderedQueryable
|
||||
from .sequence import Sequence
|
||||
from .ordered_queryable_abc import OrderedQueryableABC
|
||||
from .queryable_abc import QueryableABC
|
||||
2
src/cpl-query/cpl/query/base/default_lambda.py
Normal file
2
src/cpl-query/cpl/query/base/default_lambda.py
Normal file
@@ -0,0 +1,2 @@
|
||||
def default_lambda(x: object):
|
||||
return x
|
||||
34
src/cpl-query/cpl/query/base/ordered_queryable.py
Normal file
34
src/cpl-query/cpl/query/base/ordered_queryable.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from cpl.query.base.ordered_queryable_abc import OrderedQueryableABC
|
||||
from cpl.query.exceptions import ArgumentNoneException, ExceptionArgument
|
||||
|
||||
|
||||
class OrderedQueryable(OrderedQueryableABC):
|
||||
r"""Implementation of :class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`"""
|
||||
|
||||
def __init__(self, _t: type, _values: OrderedQueryableABC = None, _func: Callable = None):
|
||||
OrderedQueryableABC.__init__(self, _t, _values, _func)
|
||||
|
||||
def then_by(self, _func: Callable) -> OrderedQueryableABC:
|
||||
if self is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.list)
|
||||
|
||||
if _func is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.func)
|
||||
|
||||
self._funcs.append(_func)
|
||||
|
||||
return OrderedQueryable(self.type, sorted(self, key=lambda *args: [f(*args) for f in self._funcs]), _func)
|
||||
|
||||
def then_by_descending(self, _func: Callable) -> OrderedQueryableABC:
|
||||
if self is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.list)
|
||||
|
||||
if _func is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.func)
|
||||
|
||||
self._funcs.append(_func)
|
||||
return OrderedQueryable(
|
||||
self.type, sorted(self, key=lambda *args: [f(*args) for f in self._funcs], reverse=True), _func
|
||||
)
|
||||
38
src/cpl-query/cpl/query/base/ordered_queryable_abc.py
Normal file
38
src/cpl-query/cpl/query/base/ordered_queryable_abc.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Iterable
|
||||
|
||||
from cpl.query.base.queryable_abc import QueryableABC
|
||||
|
||||
|
||||
class OrderedQueryableABC(QueryableABC):
|
||||
@abstractmethod
|
||||
def __init__(self, _t: type, _values: Iterable = None, _func: Callable = None):
|
||||
QueryableABC.__init__(self, _t, _values)
|
||||
self._funcs: list[Callable] = []
|
||||
if _func is not None:
|
||||
self._funcs.append(_func)
|
||||
|
||||
@abstractmethod
|
||||
def then_by(self, func: Callable) -> OrderedQueryableABC:
|
||||
r"""Sorts OrderedList in ascending order by function
|
||||
|
||||
Parameter:
|
||||
func: :class:`Callable`
|
||||
|
||||
Returns:
|
||||
list of :class:`cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def then_by_descending(self, func: Callable) -> OrderedQueryableABC:
|
||||
r"""Sorts OrderedList in descending order by function
|
||||
|
||||
Parameter:
|
||||
func: :class:`Callable`
|
||||
|
||||
Returns:
|
||||
list of :class:`cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
|
||||
"""
|
||||
569
src/cpl-query/cpl/query/base/queryable_abc.py
Normal file
569
src/cpl-query/cpl/query/base/queryable_abc.py
Normal file
@@ -0,0 +1,569 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Callable, Union, Iterable, Any
|
||||
|
||||
from cpl.query._helper import is_number
|
||||
from cpl.query.base import default_lambda
|
||||
from cpl.query.base.sequence import Sequence
|
||||
from cpl.query.exceptions import (
|
||||
InvalidTypeException,
|
||||
ArgumentNoneException,
|
||||
ExceptionArgument,
|
||||
IndexOutOfRangeException,
|
||||
)
|
||||
|
||||
|
||||
class QueryableABC(Sequence):
|
||||
def __init__(self, t: type, values: Iterable = None):
|
||||
Sequence.__init__(self, t, values)
|
||||
|
||||
def all(self, _func: Callable = None) -> bool:
|
||||
r"""Checks if every element of list equals result found by function
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
return self.count(_func) == self.count()
|
||||
|
||||
def any(self, _func: Callable = None) -> bool:
|
||||
r"""Checks if list contains result found by function
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
return self.where(_func).count() > 0
|
||||
|
||||
def average(self, _func: Callable = None) -> Union[int, float, complex]:
|
||||
r"""Returns average value of list
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[int, float, complex]
|
||||
"""
|
||||
if _func is None and not is_number(self.type):
|
||||
raise InvalidTypeException()
|
||||
|
||||
return self.sum(_func) / self.count()
|
||||
|
||||
def contains(self, _value: object) -> bool:
|
||||
r"""Checks if list contains value given by function
|
||||
|
||||
Parameter
|
||||
---------
|
||||
value: :class:`object`
|
||||
value
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
if _value is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.value)
|
||||
|
||||
return self.where(lambda x: x == _value).count() > 0
|
||||
|
||||
def count(self, _func: Callable = None) -> int:
|
||||
r"""Returns length of list or count of found elements
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
if _func is None:
|
||||
return self.__len__()
|
||||
|
||||
return self.where(_func).count()
|
||||
|
||||
def distinct(self, _func: Callable = None) -> QueryableABC:
|
||||
r"""Returns list without redundancies
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
result = []
|
||||
known_values = []
|
||||
for element in self:
|
||||
value = _func(element)
|
||||
if value in known_values:
|
||||
continue
|
||||
|
||||
known_values.append(value)
|
||||
result.append(element)
|
||||
|
||||
return type(self)(self._type, result)
|
||||
|
||||
def element_at(self, _index: int) -> any:
|
||||
r"""Returns element at given index
|
||||
|
||||
Parameter
|
||||
---------
|
||||
_index: :class:`int`
|
||||
index
|
||||
|
||||
Returns
|
||||
-------
|
||||
Value at _index: any
|
||||
"""
|
||||
if _index is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.index)
|
||||
|
||||
if _index < 0 or _index >= self.count():
|
||||
raise IndexOutOfRangeException
|
||||
|
||||
result = self._values[_index]
|
||||
if result is None:
|
||||
raise IndexOutOfRangeException
|
||||
|
||||
return result
|
||||
|
||||
def element_at_or_default(self, _index: int) -> Optional[any]:
|
||||
r"""Returns element at given index or None
|
||||
|
||||
Parameter
|
||||
---------
|
||||
_index: :class:`int`
|
||||
index
|
||||
|
||||
Returns
|
||||
-------
|
||||
Value at _index: Optional[any]
|
||||
"""
|
||||
if _index is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.index)
|
||||
|
||||
try:
|
||||
return self._values[_index]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
def first(self) -> any:
|
||||
r"""Returns first element
|
||||
|
||||
Returns
|
||||
-------
|
||||
First element of list: any
|
||||
"""
|
||||
if self.count() == 0:
|
||||
raise IndexOutOfRangeException()
|
||||
|
||||
return self._values[0]
|
||||
|
||||
def first_or_default(self) -> any:
|
||||
r"""Returns first element or None
|
||||
|
||||
Returns
|
||||
-------
|
||||
First element of list: Optional[any]
|
||||
"""
|
||||
if self.count() == 0:
|
||||
return None
|
||||
|
||||
return self._values[0]
|
||||
|
||||
def for_each(self, _func: Callable = None):
|
||||
r"""Runs given function for each element of list
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class: `Callable`
|
||||
function to call
|
||||
"""
|
||||
if _func is not None:
|
||||
for element in self:
|
||||
_func(element)
|
||||
|
||||
return self
|
||||
|
||||
def group_by(self, _func: Callable = None) -> QueryableABC:
|
||||
r"""Groups by func
|
||||
|
||||
Returns
|
||||
-------
|
||||
Grouped list[list[any]]: any
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
groups = {}
|
||||
|
||||
for v in self:
|
||||
value = _func(v)
|
||||
if v not in groups:
|
||||
groups[value] = []
|
||||
|
||||
groups[value].append(v)
|
||||
|
||||
v = []
|
||||
for g in groups.values():
|
||||
v.append(type(self)(object, g))
|
||||
x = type(self)(type(self), v)
|
||||
return x
|
||||
|
||||
def last(self) -> any:
|
||||
r"""Returns last element
|
||||
|
||||
Returns
|
||||
-------
|
||||
Last element of list: any
|
||||
"""
|
||||
if self.count() == 0:
|
||||
raise IndexOutOfRangeException()
|
||||
|
||||
return self._values[self.count() - 1]
|
||||
|
||||
def last_or_default(self) -> any:
|
||||
r"""Returns last element or None
|
||||
|
||||
Returns
|
||||
-------
|
||||
Last element of list: Optional[any]
|
||||
"""
|
||||
if self.count() == 0:
|
||||
return None
|
||||
|
||||
return self._values[self.count() - 1]
|
||||
|
||||
def max(self, _func: Callable = None) -> object:
|
||||
r"""Returns the highest value
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
"""
|
||||
if _func is None and not is_number(self.type):
|
||||
raise InvalidTypeException()
|
||||
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
return _func(max(self, key=_func))
|
||||
|
||||
def median(self, _func=None) -> Union[int, float]:
|
||||
r"""Return the median value of data elements
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[int, float]
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
result = self.order_by(_func).select(_func).to_list()
|
||||
length = len(result)
|
||||
i = int(length / 2)
|
||||
return result[i] if length % 2 == 1 else (float(result[i - 1]) + float(result[i])) / float(2)
|
||||
|
||||
def min(self, _func: Callable = None) -> object:
|
||||
r"""Returns the lowest value
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
"""
|
||||
if _func is None and not is_number(self.type):
|
||||
raise InvalidTypeException()
|
||||
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
return _func(min(self, key=_func))
|
||||
|
||||
def order_by(self, _func: Callable = None) -> "OrderedQueryableABC":
|
||||
r"""Sorts elements by function in ascending order
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
from cpl.query.base.ordered_queryable import OrderedQueryable
|
||||
|
||||
return OrderedQueryable(self.type, sorted(self, key=_func), _func)
|
||||
|
||||
def order_by_descending(self, _func: Callable = None) -> "OrderedQueryableABC":
|
||||
r"""Sorts elements by function in descending order
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.ordered_queryable_abc.OrderedQueryableABC`
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
from cpl.query.base.ordered_queryable import OrderedQueryable
|
||||
|
||||
return OrderedQueryable(self.type, sorted(self, key=_func, reverse=True), _func)
|
||||
|
||||
def reverse(self) -> QueryableABC:
|
||||
r"""Reverses list
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
return type(self)(self._type, reversed(self._values))
|
||||
|
||||
def select(self, _func: Callable) -> QueryableABC:
|
||||
r"""Formats each element of list to a given format
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
_l = [_func(_o) for _o in self]
|
||||
_t = type(_l[0]) if len(_l) > 0 else Any
|
||||
|
||||
return type(self)(_t, _l)
|
||||
|
||||
def select_many(self, _func: Callable) -> QueryableABC:
|
||||
r"""Flattens resulting lists to one
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
# The line below is pain. I don't understand anything of it...
|
||||
# written on 09.11.2022 by Sven Heidemann
|
||||
return type(self)(object, [_a for _o in self for _a in _func(_o)])
|
||||
|
||||
def single(self) -> any:
|
||||
r"""Returns one single element of list
|
||||
|
||||
Returns
|
||||
-------
|
||||
Found value: any
|
||||
|
||||
Raises
|
||||
------
|
||||
ArgumentNoneException: when argument is None
|
||||
Exception: when argument is None or found more than one element
|
||||
"""
|
||||
if self.count() > 1:
|
||||
raise Exception("Found more than one element")
|
||||
elif self.count() == 0:
|
||||
raise Exception("Found no element")
|
||||
|
||||
return self._values[0]
|
||||
|
||||
def single_or_default(self) -> Optional[any]:
|
||||
r"""Returns one single element of list
|
||||
|
||||
Returns
|
||||
-------
|
||||
Found value: Optional[any]
|
||||
"""
|
||||
if self.count() > 1:
|
||||
raise Exception("Index out of range")
|
||||
elif self.count() == 0:
|
||||
return None
|
||||
|
||||
return self._values[0]
|
||||
|
||||
def skip(self, _index: int) -> QueryableABC:
|
||||
r"""Skips all elements from index
|
||||
|
||||
Parameter
|
||||
---------
|
||||
_index: :class:`int`
|
||||
index
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
if _index is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.index)
|
||||
|
||||
return type(self)(self.type, self._values[_index:])
|
||||
|
||||
def skip_last(self, _index: int) -> QueryableABC:
|
||||
r"""Skips all elements after index
|
||||
|
||||
Parameter
|
||||
---------
|
||||
_index: :class:`int`
|
||||
index
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
if _index is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.index)
|
||||
|
||||
index = self.count() - _index
|
||||
return type(self)(self._type, self._values[:index])
|
||||
|
||||
def sum(self, _func: Callable = None) -> Union[int, float, complex]:
|
||||
r"""Sum of all values
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[int, float, complex]
|
||||
"""
|
||||
if _func is None and not is_number(self.type):
|
||||
raise InvalidTypeException()
|
||||
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
result = 0
|
||||
for x in self:
|
||||
result += _func(x)
|
||||
|
||||
return result
|
||||
|
||||
def split(self, _func: Callable) -> QueryableABC:
|
||||
r"""Splits the list by given function
|
||||
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
seperator
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
groups = []
|
||||
group = []
|
||||
for x in self:
|
||||
v = _func(x)
|
||||
if x == v:
|
||||
groups.append(group)
|
||||
group = []
|
||||
|
||||
group.append(x)
|
||||
|
||||
groups.append(group)
|
||||
|
||||
query_groups = []
|
||||
for g in groups:
|
||||
if len(g) == 0:
|
||||
continue
|
||||
query_groups.append(type(self)(self._type, g))
|
||||
|
||||
return type(self)(self._type, query_groups)
|
||||
|
||||
def take(self, _index: int) -> QueryableABC:
|
||||
r"""Takes all elements from index
|
||||
|
||||
Parameter
|
||||
---------
|
||||
_index: :class:`int`
|
||||
index
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
if _index is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.index)
|
||||
|
||||
return type(self)(self._type, self._values[:_index])
|
||||
|
||||
def take_last(self, _index: int) -> QueryableABC:
|
||||
r"""Takes all elements after index
|
||||
|
||||
Parameter
|
||||
---------
|
||||
_index: :class:`int`
|
||||
index
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
index = self.count() - _index
|
||||
|
||||
if index >= self.count() or index < 0:
|
||||
raise IndexOutOfRangeException()
|
||||
|
||||
return type(self)(self._type, self._values[index:])
|
||||
|
||||
def where(self, _func: Callable = None) -> QueryableABC:
|
||||
r"""Select element by function
|
||||
|
||||
Parameter
|
||||
---------
|
||||
func: :class:`Callable`
|
||||
selected value
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: `cpl.query.base.queryable_abc.QueryableABC`
|
||||
"""
|
||||
if _func is None:
|
||||
raise ArgumentNoneException(ExceptionArgument.func)
|
||||
|
||||
if _func is None:
|
||||
_func = default_lambda
|
||||
|
||||
return type(self)(self.type, filter(_func, self))
|
||||
96
src/cpl-query/cpl/query/base/sequence.py
Normal file
96
src/cpl-query/cpl/query/base/sequence.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
class Sequence(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, t: type, values: Iterable = None):
|
||||
assert t is not None
|
||||
assert isinstance(t, type) or t == any
|
||||
assert values is None or isinstance(values, Iterable)
|
||||
|
||||
if values is None:
|
||||
values = []
|
||||
|
||||
self._values = list(values)
|
||||
|
||||
if t is None:
|
||||
t = object
|
||||
|
||||
self._type = t
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._values)
|
||||
|
||||
def __next__(self):
|
||||
return next(iter(self._values))
|
||||
|
||||
def __len__(self):
|
||||
return self.to_list().__len__()
|
||||
|
||||
@classmethod
|
||||
def __class_getitem__(cls, _t: type) -> type:
|
||||
return _t
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{type(self).__name__} {self.to_list().__repr__()}>"
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
def _check_type(self, __object: any):
|
||||
if self._type == any:
|
||||
return
|
||||
|
||||
if (
|
||||
self._type is not None
|
||||
and type(__object) != self._type
|
||||
and not isinstance(type(__object), self._type)
|
||||
and not issubclass(type(__object), self._type)
|
||||
):
|
||||
raise Exception(f"Unexpected type: {type(__object)}\nExpected type: {self._type}")
|
||||
|
||||
def to_list(self) -> list:
|
||||
r"""Converts :class: `cpl.query.base.sequence_abc.SequenceABC` to :class: `list`
|
||||
|
||||
Returns:
|
||||
:class: `list`
|
||||
"""
|
||||
return [x for x in self._values]
|
||||
|
||||
def copy(self) -> "Sequence":
|
||||
r"""Creates a copy of sequence
|
||||
|
||||
Returns:
|
||||
Sequence
|
||||
"""
|
||||
return type(self)(self._type, self.to_list())
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "Sequence":
|
||||
r"""Returns an empty sequence
|
||||
|
||||
Returns:
|
||||
Sequence object that contains no elements
|
||||
"""
|
||||
return cls(object, [])
|
||||
|
||||
def index_of(self, _object: object) -> int:
|
||||
r"""Returns the index of given element
|
||||
|
||||
Returns:
|
||||
Index of object
|
||||
|
||||
Raises:
|
||||
IndexError if object not in sequence
|
||||
"""
|
||||
for i, o in enumerate(self):
|
||||
if o == _object:
|
||||
return i
|
||||
|
||||
raise IndexError
|
||||
|
||||
@classmethod
|
||||
def range(cls, start: int, length: int) -> "Sequence":
|
||||
return cls(int, range(start, length))
|
||||
@@ -1,213 +0,0 @@
|
||||
from itertools import islice, groupby, chain
|
||||
from typing import Generic, Callable, Iterable, Iterator, Dict, Tuple, Optional
|
||||
|
||||
from cpl.core.typing import T, R
|
||||
from cpl.query.typing import Predicate, K, Selector
|
||||
|
||||
|
||||
class Enumerable(Generic[T]):
|
||||
def __init__(self, source: Iterable[T]):
|
||||
self._source = source
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._source)
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(list(self._source))
|
||||
|
||||
def where(self, f: Predicate) -> "Enumerable[T]":
|
||||
return Enumerable(x for x in self._source if f(x))
|
||||
|
||||
def select(self, f: Selector) -> "Enumerable[R]":
|
||||
return Enumerable(f(x) for x in self._source)
|
||||
|
||||
def select_many(self, f: Callable[[T], Iterable[R]]) -> "Enumerable[R]":
|
||||
return Enumerable(y for x in self._source for y in f(x))
|
||||
|
||||
def take(self, count: int) -> "Enumerable[T]":
|
||||
return Enumerable(islice(self._source, count))
|
||||
|
||||
def skip(self, count: int) -> "Enumerable[T]":
|
||||
return Enumerable(islice(self._source, count, None))
|
||||
|
||||
def take_while(self, f: Predicate) -> "Enumerable[T]":
|
||||
def generator():
|
||||
for x in self._source:
|
||||
if f(x):
|
||||
yield x
|
||||
else:
|
||||
break
|
||||
|
||||
return Enumerable(generator())
|
||||
|
||||
def skip_while(self, f: Predicate) -> "Enumerable[T]":
|
||||
def generator():
|
||||
it = iter(self._source)
|
||||
for x in it:
|
||||
if not f(x):
|
||||
yield x
|
||||
break
|
||||
yield from it
|
||||
|
||||
return Enumerable(generator())
|
||||
|
||||
def distinct(self) -> "Enumerable[T]":
|
||||
def generator():
|
||||
seen = set()
|
||||
for x in self._source:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
yield x
|
||||
|
||||
return Enumerable(generator())
|
||||
|
||||
def union(self, other: Iterable[T]) -> "Enumerable[T]":
|
||||
return Enumerable(chain(self.distinct(), Enumerable(other).distinct())).distinct()
|
||||
|
||||
def intersect(self, other: Iterable[T]) -> "Enumerable[T]":
|
||||
other_set = set(other)
|
||||
return Enumerable(x for x in self._source if x in other_set)
|
||||
|
||||
def except_(self, other: Iterable[T]) -> "Enumerable[T]":
|
||||
other_set = set(other)
|
||||
return Enumerable(x for x in self._source if x not in other_set)
|
||||
|
||||
def concat(self, other: Iterable[T]) -> "Enumerable[T]":
|
||||
return Enumerable(chain(self._source, other))
|
||||
|
||||
# --- Aggregation ---
|
||||
def count(self) -> int:
|
||||
return sum(1 for _ in self._source)
|
||||
|
||||
def sum(self, f: Optional[Selector] = None) -> R:
|
||||
if f:
|
||||
return sum(f(x) for x in self._source)
|
||||
return sum(self._source) # type: ignore
|
||||
|
||||
def min(self, f: Optional[Selector] = None) -> R:
|
||||
if f:
|
||||
return min(f(x) for x in self._source)
|
||||
return min(self._source) # type: ignore
|
||||
|
||||
def max(self, f: Optional[Selector] = None) -> R:
|
||||
if f:
|
||||
return max(f(x) for x in self._source)
|
||||
return max(self._source) # type: ignore
|
||||
|
||||
def average(self, f: Optional[Callable[[T], float]] = None) -> float:
|
||||
values = list(self.select(f).to_list()) if f else list(self._source)
|
||||
return sum(values) / len(values) if values else 0.0
|
||||
|
||||
def aggregate(self, func: Callable[[R, T], R], seed: Optional[R] = None) -> R:
|
||||
it = iter(self._source)
|
||||
if seed is None:
|
||||
acc = next(it) # type: ignore
|
||||
else:
|
||||
acc = seed
|
||||
for x in it:
|
||||
acc = func(acc, x)
|
||||
return acc
|
||||
|
||||
def any(self, f: Optional[Predicate] = None) -> bool:
|
||||
return any(f(x) if f else x for x in self._source)
|
||||
|
||||
def all(self, f: Predicate) -> bool:
|
||||
return all(f(x) for x in self._source)
|
||||
|
||||
def contains(self, value: T) -> bool:
|
||||
return any(x == value for x in self._source)
|
||||
|
||||
def sequence_equal(self, other: Iterable[T]) -> bool:
|
||||
return list(self._source) == list(other)
|
||||
|
||||
def group_by(self, key_f: Callable[[T], K]) -> "Enumerable[Tuple[K, List[T]]]":
|
||||
def generator():
|
||||
sorted_data = sorted(self._source, key=key_f)
|
||||
for key, group in groupby(sorted_data, key=key_f):
|
||||
yield (key, list(group))
|
||||
|
||||
return Enumerable(generator())
|
||||
|
||||
def join(
|
||||
self, inner: Iterable[R], outer_key: Callable[[T], K], inner_key: Callable[[R], K], result: Callable[[T, R], R]
|
||||
) -> "Enumerable[R]":
|
||||
def generator():
|
||||
lookup: Dict[K, List[R]] = {}
|
||||
for i in inner:
|
||||
k = inner_key(i)
|
||||
lookup.setdefault(k, []).append(i)
|
||||
for o in self._source:
|
||||
k = outer_key(o)
|
||||
if k in lookup:
|
||||
for i in lookup[k]:
|
||||
yield result(o, i)
|
||||
|
||||
return Enumerable(generator())
|
||||
|
||||
def first(self, f: Optional[Predicate] = None) -> T:
|
||||
if f:
|
||||
for x in self._source:
|
||||
if f(x):
|
||||
return x
|
||||
raise ValueError("No matching element")
|
||||
return next(iter(self._source))
|
||||
|
||||
def first_or_default(self, default: Optional[T] = None) -> Optional[T]:
|
||||
return next(iter(self._source), default)
|
||||
|
||||
def last(self) -> T:
|
||||
return list(self._source)[-1]
|
||||
|
||||
def single(self, f: Optional[Predicate] = None) -> T:
|
||||
items = [x for x in self._source if f(x)] if f else list(self._source)
|
||||
if len(items) != 1:
|
||||
raise ValueError("Sequence does not contain exactly one element")
|
||||
return items[0]
|
||||
|
||||
def to_list(self) -> "List[T]":
|
||||
from cpl.query.list import List
|
||||
|
||||
return List(self)
|
||||
|
||||
def to_set(self) -> "Set[T]":
|
||||
from cpl.query.set import Set
|
||||
|
||||
return Set(self)
|
||||
|
||||
def to_dict(self, key_f: Callable[[T], K], value_f: Selector) -> Dict[K, R]:
|
||||
return {key_f(x): value_f(x) for x in self._source}
|
||||
|
||||
def cast(self, t: Selector) -> "Enumerable[R]":
|
||||
return Enumerable(t(x) for x in self._source)
|
||||
|
||||
def of_type(self, t: type) -> "Enumerable[T]":
|
||||
return Enumerable(x for x in self._source if isinstance(x, t))
|
||||
|
||||
def reverse(self) -> "Enumerable[T]":
|
||||
return Enumerable(reversed(list(self._source)))
|
||||
|
||||
def zip(self, other: Iterable[R]) -> "Enumerable[Tuple[T, R]]":
|
||||
return Enumerable(zip(self._source, other))
|
||||
|
||||
def order_by(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
|
||||
from cpl.query.ordered_enumerable import OrderedEnumerable
|
||||
|
||||
return OrderedEnumerable(self._source, [(key_selector, False)])
|
||||
|
||||
def order_by_descending(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
|
||||
from cpl.query.ordered_enumerable import OrderedEnumerable
|
||||
|
||||
return OrderedEnumerable(self._source, [(key_selector, True)])
|
||||
|
||||
@staticmethod
|
||||
def range(start: int, count: int) -> "Enumerable[int]":
|
||||
return Enumerable(range(start, start + count))
|
||||
|
||||
@staticmethod
|
||||
def repeat(value: T, count: int) -> "Enumerable[T]":
|
||||
return Enumerable(value for _ in range(count))
|
||||
|
||||
@staticmethod
|
||||
def empty() -> "Enumerable[T]":
|
||||
return Enumerable([])
|
||||
2
src/cpl-query/cpl/query/enumerable/__init__.py
Normal file
2
src/cpl-query/cpl/query/enumerable/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .enumerable import Enumerable
|
||||
from .enumerable_abc import EnumerableABC
|
||||
12
src/cpl-query/cpl/query/enumerable/enumerable.py
Normal file
12
src/cpl-query/cpl/query/enumerable/enumerable.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from cpl.query.enumerable.enumerable_abc import EnumerableABC
|
||||
|
||||
|
||||
def _default_lambda(x: object):
|
||||
return x
|
||||
|
||||
|
||||
class Enumerable(EnumerableABC):
|
||||
r"""Implementation of :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC`"""
|
||||
|
||||
def __init__(self, t: type = None, values: list = None):
|
||||
EnumerableABC.__init__(self, t, values)
|
||||
21
src/cpl-query/cpl/query/enumerable/enumerable_abc.py
Normal file
21
src/cpl-query/cpl/query/enumerable/enumerable_abc.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from cpl.query.base.queryable_abc import QueryableABC
|
||||
|
||||
|
||||
class EnumerableABC(QueryableABC):
|
||||
r"""ABC to define functions on list"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, t: type = None, values: list = None):
|
||||
QueryableABC.__init__(self, t, values)
|
||||
|
||||
def to_iterable(self) -> "IterableABC":
|
||||
r"""Converts :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC` to :class: `cpl.query.iterable.iterable_abc.IterableABC`
|
||||
|
||||
Returns:
|
||||
:class: `cpl.query.iterable.iterable_abc.IterableABC`
|
||||
"""
|
||||
from cpl.query.iterable.iterable import Iterable
|
||||
|
||||
return Iterable(self._type, self.to_list())
|
||||
33
src/cpl-query/cpl/query/exceptions.py
Normal file
33
src/cpl-query/cpl/query/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# models
|
||||
class ExceptionArgument(Enum):
|
||||
list = "list"
|
||||
func = "func"
|
||||
type = "type"
|
||||
value = "value"
|
||||
index = "index"
|
||||
|
||||
|
||||
# exceptions
|
||||
class ArgumentNoneException(Exception):
|
||||
r"""Exception when argument is None"""
|
||||
|
||||
def __init__(self, arg: ExceptionArgument):
|
||||
Exception.__init__(self, f"argument {arg} is None")
|
||||
|
||||
|
||||
class IndexOutOfRangeException(Exception):
|
||||
r"""Exception when index is out of range"""
|
||||
|
||||
def __init__(self, err: str = None):
|
||||
Exception.__init__(self, f"List index out of range" if err is None else err)
|
||||
|
||||
|
||||
class InvalidTypeException(Exception):
|
||||
r"""Exception when type is invalid"""
|
||||
|
||||
|
||||
class WrongTypeException(Exception):
|
||||
r"""Exception when type is unexpected"""
|
||||
1
src/cpl-query/cpl/query/extension/__init__.py
Normal file
1
src/cpl-query/cpl/query/extension/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .list import List
|
||||
36
src/cpl-query/cpl/query/extension/list.py
Normal file
36
src/cpl-query/cpl/query/extension/list.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from cpl.query.enumerable.enumerable_abc import EnumerableABC
|
||||
from cpl.query.iterable.iterable import Iterable
|
||||
|
||||
|
||||
class List(Iterable):
|
||||
r"""Implementation of :class: `cpl.query.extension.iterable.Iterable`"""
|
||||
|
||||
def __init__(self, t: type = None, values: Iterable = None):
|
||||
Iterable.__init__(self, t, values)
|
||||
|
||||
def __getitem__(self, *args):
|
||||
return self._values.__getitem__(*args)
|
||||
|
||||
def __setitem__(self, *args):
|
||||
self._values.__setitem__(*args)
|
||||
|
||||
def __delitem__(self, *args):
|
||||
self._values.__delitem__(*args)
|
||||
|
||||
def to_enumerable(self) -> EnumerableABC:
|
||||
r"""Converts :class: `cpl.query.iterable.iterable_abc.IterableABC` to :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC`
|
||||
|
||||
Returns:
|
||||
:class: `cpl.query.enumerable.enumerable_abc.EnumerableABC`
|
||||
"""
|
||||
from cpl.query.enumerable.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._type, self.to_list())
|
||||
|
||||
def to_iterable(self) -> Iterable:
|
||||
r"""Converts :class: `cpl.query.enumerable.enumerable_abc.EnumerableABC` to :class: `cpl.query.iterable.iterable_abc.IterableABC`
|
||||
|
||||
Returns:
|
||||
:class: `cpl.query.iterable.iterable_abc.IterableABC`
|
||||
"""
|
||||
return Iterable(self._type, self.to_list())
|
||||
@@ -1,65 +0,0 @@
|
||||
from typing import Generic, Iterable, Iterator, Optional
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
|
||||
class ImmutableList(Generic[T], Enumerable[T]):
|
||||
def __init__(self, source: Optional[Iterable[T]] = None):
|
||||
Enumerable.__init__(self, [])
|
||||
if source is None:
|
||||
source = []
|
||||
elif not isinstance(source, list):
|
||||
source = list(source)
|
||||
|
||||
self.__source = source
|
||||
|
||||
@property
|
||||
def _source(self) -> list[T]:
|
||||
return self.__source
|
||||
|
||||
@_source.setter
|
||||
def _source(self, value: list[T]) -> None:
|
||||
self.__source = value
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._source)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._source)
|
||||
|
||||
def __getitem__(self, index: int) -> T:
|
||||
return self._source[index]
|
||||
|
||||
def __contains__(self, item: T) -> bool:
|
||||
return item in self._source
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"List({self._source!r})"
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self._source)
|
||||
|
||||
def add(self, item: T) -> None:
|
||||
self._source.append(item)
|
||||
|
||||
def extend(self, items: Iterable[T]) -> None:
|
||||
self._source.extend(items)
|
||||
|
||||
def insert(self, index: int, item: T) -> None:
|
||||
self._source.insert(index, item)
|
||||
|
||||
def remove(self, item: T) -> None:
|
||||
self._source.remove(item)
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
return self._source.pop(index)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._source.clear()
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._source)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user