Compare commits
9 Commits
2025.09.23
...
2025.09.24
| Author | SHA1 | Date | |
|---|---|---|---|
| b49f663ae0 | |||
| 287f5e3149 | |||
| 4c8cd988cc | |||
| cdb4a0fb34 | |||
| cf8edafd39 | |||
| 01a2ff7166 | |||
| 2da6d679ad | |||
| a1cfe76047 | |||
| c71a3df62c |
@@ -4,8 +4,12 @@ from cpl import api
|
||||
from cpl.api.application.web_app import WebApp
|
||||
from cpl.application import ApplicationBuilder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema import AuthUser, Role
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.utils.cache import Cache
|
||||
from custom.api.src.scoped_service import ScopedService
|
||||
from service import PingService
|
||||
|
||||
|
||||
@@ -21,6 +25,11 @@ def main():
|
||||
builder.services.add_transient(PingService)
|
||||
builder.services.add_module(api)
|
||||
|
||||
builder.services.add_scoped(ScopedService)
|
||||
|
||||
builder.services.add_cache(AuthUser)
|
||||
builder.services.add_cache(Role)
|
||||
|
||||
app = builder.build()
|
||||
app.with_logging()
|
||||
app.with_database()
|
||||
@@ -31,6 +40,36 @@ def main():
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator])
|
||||
app.with_routes_directory("routes")
|
||||
|
||||
provider = builder.service_provider
|
||||
user_cache = provider.get_service(Cache[AuthUser])
|
||||
role_cache = provider.get_service(Cache[Role])
|
||||
|
||||
if role_cache == user_cache:
|
||||
raise Exception("Cache service is not working")
|
||||
|
||||
s1 = provider.get_service(ScopedService)
|
||||
s2 = provider.get_service(ScopedService)
|
||||
|
||||
if s1.name == s2.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
with provider.create_scope() as scope:
|
||||
s3 = scope.get_service(ScopedService)
|
||||
s4 = scope.get_service(ScopedService)
|
||||
|
||||
if s3.name != s4.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
if s1.name == s3.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
Console.write_line(
|
||||
s1.name,
|
||||
s2.name,
|
||||
s3.name,
|
||||
s4.name,
|
||||
)
|
||||
|
||||
app.run()
|
||||
|
||||
|
||||
@@ -5,12 +5,17 @@ from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.api import APILogger
|
||||
from cpl.api.router import Router
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProvider
|
||||
from custom.api.src.scoped_service import ScopedService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
# @Router.authorize(permissions=[Permissions.administrator])
|
||||
# @Router.authorize(policies=["test"])
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger):
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
|
||||
logger.info(f"Ping: {ping}")
|
||||
|
||||
Console.write_line(scoped.name)
|
||||
return JSONResponse(ping.ping(r))
|
||||
14
example/custom/api/src/scoped_service.py
Normal file
14
example/custom/api/src/scoped_service.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.utils.string import String
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self._name = String.random(8)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -3,7 +3,7 @@ from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.log import LoggerABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from model.city import City
|
||||
from model.city_dao import CityDao
|
||||
from model.user import User
|
||||
@@ -11,7 +11,7 @@ from model.user_dao import UserDao
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
|
||||
self._logger = services.get_service(LoggerABC)
|
||||
@@ -1,7 +1,6 @@
|
||||
from cpl.application.abc import ApplicationABC
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency.scope import Scope
|
||||
from cpl.dependency import ServiceProvider
|
||||
from di.static_test import StaticTest
|
||||
from di.test_abc import TestABC
|
||||
from di.test_service import TestService
|
||||
@@ -10,33 +9,37 @@ from di.tester import Tester
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
|
||||
def _part_of_scoped(self):
|
||||
ts: TestService = self._services.get_service(TestService)
|
||||
ts.run()
|
||||
|
||||
def configure(self): ...
|
||||
|
||||
def main(self):
|
||||
with self._services.create_scope() as scope:
|
||||
Console.write_line("Scope1")
|
||||
ts: TestService = scope.service_provider.get_service(TestService)
|
||||
ts: TestService = scope.get_service(TestService)
|
||||
ts.run()
|
||||
dit: DITesterService = scope.service_provider.get_service(DITesterService)
|
||||
dit: DITesterService = scope.get_service(DITesterService)
|
||||
dit.run()
|
||||
|
||||
if ts.name != dit.name:
|
||||
raise Exception("DI is broken!")
|
||||
|
||||
with self._services.create_scope() as scope:
|
||||
Console.write_line("Scope2")
|
||||
ts: TestService = scope.service_provider.get_service(TestService)
|
||||
ts: TestService = scope.get_service(TestService)
|
||||
ts.run()
|
||||
dit: DITesterService = scope.service_provider.get_service(DITesterService)
|
||||
dit: DITesterService = scope.get_service(DITesterService)
|
||||
dit.run()
|
||||
|
||||
if ts.name != dit.name:
|
||||
raise Exception("DI is broken!")
|
||||
|
||||
Console.write_line("Global")
|
||||
self._part_of_scoped()
|
||||
StaticTest.test()
|
||||
|
||||
self._services.get_service(Tester)
|
||||
Console.write_line(self._services.get_services(list[TestABC]))
|
||||
Console.write_line(self._services.get_services(TestABC))
|
||||
@@ -6,6 +6,10 @@ class DITesterService:
|
||||
def __init__(self, ts: TestService):
|
||||
self._ts = ts
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._ts.name
|
||||
|
||||
def run(self):
|
||||
Console.write_line("DIT: ")
|
||||
self._ts.run()
|
||||
@@ -1,5 +1,5 @@
|
||||
from cpl.application.abc import StartupABC
|
||||
from cpl.dependency import ServiceProviderABC, ServiceCollection
|
||||
from cpl.dependency import ServiceProvider, ServiceCollection
|
||||
from di.di_tester_service import DITesterService
|
||||
from di.test1_service import Test1Service
|
||||
from di.test2_service import Test2Service
|
||||
@@ -12,9 +12,11 @@ class Startup(StartupABC):
|
||||
def __init__(self):
|
||||
StartupABC.__init__(self)
|
||||
|
||||
def configure_configuration(self): ...
|
||||
@staticmethod
|
||||
def configure_configuration(): ...
|
||||
|
||||
def configure_services(self, services: ServiceCollection) -> ServiceProviderABC:
|
||||
@staticmethod
|
||||
def configure_services(services: ServiceCollection) -> ServiceProvider:
|
||||
services.add_scoped(TestService)
|
||||
services.add_scoped(DITesterService)
|
||||
|
||||
10
example/custom/di/src/di/static_test.py
Normal file
10
example/custom/di/src/di/static_test.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.dependency import ServiceProvider, ServiceProvider
|
||||
from cpl.dependency.inject import inject
|
||||
from di.test_service import TestService
|
||||
|
||||
|
||||
class StaticTest:
|
||||
@staticmethod
|
||||
@inject
|
||||
def test(services: ServiceProvider, t1: TestService):
|
||||
t1.run()
|
||||
@@ -6,7 +6,7 @@ from di.test_abc import TestABC
|
||||
|
||||
class Test1Service(TestABC):
|
||||
def __init__(self):
|
||||
TestABC.__init__(self, String.random_string(string.ascii_lowercase, 8))
|
||||
TestABC.__init__(self, String.random(8))
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -6,7 +6,7 @@ from di.test_abc import TestABC
|
||||
|
||||
class Test2Service(TestABC):
|
||||
def __init__(self):
|
||||
TestABC.__init__(self, String.random_string(string.ascii_lowercase, 8))
|
||||
TestABC.__init__(self, String.random(8))
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -1,5 +1,3 @@
|
||||
import string
|
||||
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.utils.string import String
|
||||
|
||||
@@ -8,5 +6,9 @@ class TestService:
|
||||
def __init__(self):
|
||||
self._name = String.random(8)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -4,19 +4,20 @@ from typing import Optional
|
||||
from cpl.application.abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.log import LoggerABC
|
||||
from cpl.core.pipes import IPAddressPipe
|
||||
from cpl.mail import EMail, EMailClientABC
|
||||
from cpl.query.extension.list import List
|
||||
from cpl.query import List
|
||||
from general.scoped_service import ScopedService
|
||||
from test_service import TestService
|
||||
from test_settings import TestSettings
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
self._logger = self._services.get_service(LoggerABC)
|
||||
self._mailer = self._services.get_service(EMailClientABC)
|
||||
@@ -38,7 +39,7 @@ class Application(ApplicationABC):
|
||||
def main(self):
|
||||
self._logger.debug(f"Host: {Environment.get_host_name()}")
|
||||
self._logger.debug(f"Environment: {Environment.get_environment()}")
|
||||
Console.write_line(List(int, range(0, 10)).select(lambda x: f"x={x}").to_list())
|
||||
Console.write_line(List(range(0, 10)).select(lambda x: f"x={x}").to_list())
|
||||
Console.spinner("Test", self._wait, 2, spinner_foreground_color="red")
|
||||
test: TestService = self._services.get_service(TestService)
|
||||
ip_pipe: IPAddressPipe = self._services.get_service(IPAddressPipe)
|
||||
@@ -48,10 +49,21 @@ class Application(ApplicationABC):
|
||||
Console.write_line(f"DI working: {test == test2 and ip_pipe != ip_pipe2}")
|
||||
Console.write_line(self._services.get_service(LoggerABC))
|
||||
|
||||
scope = self._services.create_scope()
|
||||
Console.write_line("scope", scope)
|
||||
with self._services.create_scope() as s:
|
||||
Console.write_line("with scope", s)
|
||||
root_scoped_service = self._services.get_service(ScopedService)
|
||||
with self._services.create_scope() as scope:
|
||||
s_srvc1 = scope.get_service(ScopedService)
|
||||
s_srvc2 = scope.get_service(ScopedService)
|
||||
|
||||
Console.write_line(root_scoped_service)
|
||||
Console.write_line(s_srvc1)
|
||||
Console.write_line(s_srvc2)
|
||||
if root_scoped_service == s_srvc1 or s_srvc1 != s_srvc2:
|
||||
raise Exception("Root scoped service should not be equal to scoped service")
|
||||
|
||||
root_scoped_service2 = self._services.get_service(ScopedService)
|
||||
Console.write_line(root_scoped_service2)
|
||||
if root_scoped_service == root_scoped_service2:
|
||||
raise Exception("Root scoped service should be equal to root scoped service 2")
|
||||
|
||||
test_settings = Configuration.get(TestSettings)
|
||||
Console.write_line(test_settings.value)
|
||||
10
example/custom/general/src/general/scoped_service.py
Normal file
10
example/custom/general/src/general/scoped_service.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self.value = "I am a scoped service"
|
||||
Console.write_line(self.value, self)
|
||||
|
||||
def get_value(self):
|
||||
return self.value
|
||||
@@ -4,6 +4,7 @@ from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.pipes import IPAddressPipe
|
||||
from cpl.dependency import ServiceCollection
|
||||
from general.scoped_service import ScopedService
|
||||
from test_service import TestService
|
||||
|
||||
|
||||
@@ -21,3 +22,4 @@ class Startup(StartupABC):
|
||||
services.add_module(mail)
|
||||
services.add_transient(IPAddressPipe)
|
||||
services.add_singleton(TestService)
|
||||
services.add_scoped(ScopedService)
|
||||
@@ -1,10 +1,10 @@
|
||||
from cpl.application.abc import ApplicationExtensionABC
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class TestExtension(ApplicationExtensionABC):
|
||||
|
||||
@staticmethod
|
||||
def run(services: ServiceProviderABC):
|
||||
def run(services: ServiceProvider):
|
||||
Console.write_line("Hello World from App Extension")
|
||||
@@ -1,10 +1,10 @@
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.core.pipes.ip_address_pipe import IPAddressPipe
|
||||
|
||||
|
||||
class TestService:
|
||||
def __init__(self, provider: ServiceProviderABC):
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
self._provider = provider
|
||||
|
||||
def run(self):
|
||||
60
example/custom/query/main.py
Normal file
60
example/custom/query/main.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.utils.benchmark import Benchmark
|
||||
from cpl.query.enumerable import Enumerable
|
||||
from cpl.query.immutable_list import ImmutableList
|
||||
from cpl.query.list import List
|
||||
from cpl.query.set import Set
|
||||
|
||||
|
||||
def _default():
|
||||
Console.write_line(Enumerable.empty().to_list())
|
||||
|
||||
Console.write_line(Enumerable.range(0, 100).length)
|
||||
Console.write_line(Enumerable.range(0, 100).to_list())
|
||||
|
||||
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
|
||||
Console.write_line(
|
||||
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
|
||||
)
|
||||
Console.write_line(List)
|
||||
|
||||
s =Enumerable.range(0, 10).to_set()
|
||||
Console.write_line(s)
|
||||
s.add(1)
|
||||
Console.write_line(s)
|
||||
|
||||
data = Enumerable(
|
||||
[
|
||||
{"name": "Alice", "age": 30},
|
||||
{"name": "Dave", "age": 35},
|
||||
{"name": "Charlie", "age": 25},
|
||||
{"name": "Bob", "age": 25},
|
||||
]
|
||||
)
|
||||
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
|
||||
|
||||
|
||||
def t_benchmark(data: list):
|
||||
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all(
|
||||
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
|
||||
)
|
||||
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
|
||||
|
||||
|
||||
def main():
|
||||
N = 10_000_000
|
||||
data = list(range(N))
|
||||
#t_benchmark(data)
|
||||
|
||||
Console.write_line()
|
||||
_default()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,14 +1,14 @@
|
||||
from cpl.application import ApplicationABC
|
||||
from cpl.core.configuration import ConfigurationABC
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.translation.translate_pipe import TranslatePipe
|
||||
from cpl.translation.translation_service_abc import TranslationServiceABC
|
||||
from cpl.translation.translation_settings import TranslationSettings
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC):
|
||||
def __init__(self, config: ConfigurationABC, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, config, services)
|
||||
|
||||
self._translate: TranslatePipe = services.get_service(TranslatePipe)
|
||||
@@ -1,6 +1,6 @@
|
||||
from cpl.application import StartupABC
|
||||
from cpl.core.configuration import ConfigurationABC
|
||||
from cpl.dependency import ServiceProviderABC, ServiceCollection
|
||||
from cpl.dependency import ServiceProvider, ServiceCollection
|
||||
from cpl.core.environment import Environment
|
||||
|
||||
|
||||
@@ -12,6 +12,6 @@ class Startup(StartupABC):
|
||||
configuration.add_json_file("appsettings.json")
|
||||
return configuration
|
||||
|
||||
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProviderABC:
|
||||
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProvider:
|
||||
services.add_translation()
|
||||
return services.build()
|
||||
@@ -27,14 +27,14 @@ from cpl.api.settings import ApiSettings
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||
from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
super().__init__(services, [auth, api])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
@@ -44,15 +44,15 @@ class WebApp(ApplicationABC):
|
||||
self._policies = services.get_service(PolicyRegistry)
|
||||
self._routes = services.get_service(RouteRegistry)
|
||||
|
||||
self._middleware: list[Middleware] = [
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
]
|
||||
self._middleware: list[Middleware] = []
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||
Exception: self._handle_exception,
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
self.with_middleware(RequestMiddleware)
|
||||
self.with_middleware(LoggingMiddleware)
|
||||
|
||||
async def _handle_exception(self, request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
self._logger.error(exc)
|
||||
@@ -168,9 +168,9 @@ class WebApp(ApplicationABC):
|
||||
self._check_for_app()
|
||||
|
||||
if isinstance(middleware, Middleware):
|
||||
self._middleware.append(middleware)
|
||||
self._middleware.append(inject(middleware))
|
||||
elif callable(middleware):
|
||||
self._middleware.append(Middleware(middleware))
|
||||
self._middleware.append(Middleware(inject(middleware)))
|
||||
else:
|
||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||
|
||||
@@ -220,7 +220,7 @@ class WebApp(ApplicationABC):
|
||||
self._validate_policies()
|
||||
|
||||
if self._app is None:
|
||||
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
||||
routes = [route.to_starlette(inject) for route in self._routes.all()]
|
||||
|
||||
app = Starlette(
|
||||
routes=routes,
|
||||
|
||||
@@ -9,12 +9,10 @@ from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
|
||||
@@ -9,12 +9,10 @@ from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
|
||||
@@ -6,12 +6,10 @@ from starlette.types import Receive, Scope, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class LoggingMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
|
||||
@@ -9,17 +9,18 @@ from starlette.types import Scope, Receive, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger):
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
self._logger = logger
|
||||
|
||||
self._ctx_token = None
|
||||
@@ -29,7 +30,8 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
await self._app(scope, receive, send)
|
||||
with self._provider.create_scope():
|
||||
inject(await self._app(scope, receive, send))
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.typing import HTTPMethods
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class Router:
|
||||
@@ -95,9 +96,7 @@ class Router:
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
if not registry:
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
@@ -144,9 +143,8 @@ class Router:
|
||||
"""
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
|
||||
def inner(fn):
|
||||
path = getattr(fn, "_route_path", None)
|
||||
|
||||
@@ -5,7 +5,7 @@ from cpl.application.host import Host
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.log_settings import LogSettings
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
def __not_implemented__(package: str, func: Callable):
|
||||
@@ -16,12 +16,12 @@ class ApplicationABC(ABC):
|
||||
r"""ABC for the Application class
|
||||
|
||||
Parameters:
|
||||
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
|
||||
services: :class:`cpl.dependency.service_provider.ServiceProvider`
|
||||
Contains instances of prepared objects
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
|
||||
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
|
||||
self._services = services
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class ApplicationExtensionABC(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def run(services: ServiceProviderABC): ...
|
||||
def run(services: ServiceProvider): ...
|
||||
|
||||
@@ -7,6 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC
|
||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.dependency.context import get_provider, use_root_provider
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||
@@ -21,6 +22,7 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
self._app = app if app is not None else ApplicationABC
|
||||
|
||||
self._services = ServiceCollection()
|
||||
use_root_provider(self._services.build())
|
||||
|
||||
self._startup: Optional[StartupABC] = None
|
||||
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
|
||||
@@ -34,7 +36,12 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
|
||||
@property
|
||||
def service_provider(self):
|
||||
return self._services.build()
|
||||
provider = get_provider()
|
||||
if provider is None:
|
||||
provider = self._services.build()
|
||||
use_root_provider(provider)
|
||||
|
||||
return provider
|
||||
|
||||
def validate_app_required_modules(self, app: ApplicationABC):
|
||||
for module in app.required_modules:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class KeycloakUser:
|
||||
@@ -32,5 +32,5 @@ class KeycloakUser:
|
||||
def id(self) -> str:
|
||||
from cpl.auth import KeycloakAdmin
|
||||
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user_id(self._username)
|
||||
|
||||
@@ -10,7 +10,8 @@ from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Id, SerialId
|
||||
from cpl.core.utils.credential_manager import CredentialManager
|
||||
from cpl.database.abc.db_model_abc import DbModelABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
_logger = Logger(__name__)
|
||||
|
||||
@@ -47,7 +48,7 @@ class ApiKey(DbModelABC):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
|
||||
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
|
||||
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
|
||||
|
||||
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class AuthUser(DbModelABC):
|
||||
@@ -36,12 +36,12 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("username")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@@ -51,12 +51,12 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("email")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@@ -64,26 +64,26 @@ class AuthUser(DbModelABC):
|
||||
async def roles(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
return await auth_user_dao.get_permissions(self.id)
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
return await auth_user_dao.has_permission(self.id, permission)
|
||||
|
||||
async def anonymize(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
|
||||
self._keycloak_id = str(uuid.UUID(int=0))
|
||||
await auth_user_dao.update(self)
|
||||
|
||||
@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
@@ -36,7 +36,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
||||
result = await self._db.select_map(
|
||||
f"""
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class ApiKeyPermission(DbJoinModelABC):
|
||||
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def api_key(self):
|
||||
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
|
||||
|
||||
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
|
||||
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
|
||||
return await api_key_dao.get_by_id(self._api_key_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -6,7 +6,7 @@ from async_property import async_property
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class Role(DbModelABC):
|
||||
@@ -44,22 +44,22 @@ class Role(DbModelABC):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def users(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
|
||||
p = await permission_dao.get_by_name(permission.value)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class RolePermission(DbModelABC):
|
||||
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class RoleUser(DbJoinModelABC):
|
||||
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
|
||||
async def user(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
return await auth_user_dao.get_by_id(self._user_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@@ -2,15 +2,15 @@ from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
||||
|
||||
|
||||
def set_user(user: Optional[AuthUser]):
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
|
||||
logger = ServiceProviderABC.get_global_service(LoggerABC)
|
||||
logger = get_provider().get_service(LoggerABC)
|
||||
logger.trace("Setting user context", user.id)
|
||||
_user_context.set(user)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from starlette.requests import Request
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source, Messages
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class StructuredLogger(Logger):
|
||||
@@ -99,10 +100,9 @@ class StructuredLogger(Logger):
|
||||
if user is None:
|
||||
return
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
|
||||
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
kc_user = keycloak.get_user(user.keycloak_id)
|
||||
message["user"] = {
|
||||
"id": str(user.id),
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import inspect
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.log import LoggerABC, LogLevel
|
||||
from cpl.core.typing import Messages, Source
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.core.typing import Messages
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class WrappedLogger(LoggerABC):
|
||||
@@ -11,18 +13,20 @@ class WrappedLogger(LoggerABC):
|
||||
LoggerABC.__init__(self)
|
||||
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
|
||||
|
||||
t_logger = ServiceProviderABC.get_global_service(LoggerABC)
|
||||
self._t_logger = type(t_logger) if t_logger is not None else None
|
||||
self._source = None
|
||||
self._file_prefix = file_prefix
|
||||
|
||||
self._set_logger()
|
||||
|
||||
def _set_logger(self):
|
||||
if self._t_logger is None:
|
||||
raise Exception("No LoggerABC service registered in ServiceProviderABC")
|
||||
@inject
|
||||
def _set_logger(self, services: ServiceProvider):
|
||||
from cpl.core.log import Logger
|
||||
|
||||
self._logger = self._t_logger(self._source, self._file_prefix)
|
||||
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
|
||||
if t_logger is None:
|
||||
raise Exception("No LoggerABC service registered in ServiceProvider")
|
||||
|
||||
self._logger = t_logger(self._source, self._file_prefix)
|
||||
|
||||
def set_level(self, level: LogLevel):
|
||||
self._logger.set_level(level)
|
||||
@@ -39,8 +43,8 @@ class WrappedLogger(LoggerABC):
|
||||
from cpl.dependency import ServiceCollection
|
||||
|
||||
ignore_classes = [
|
||||
ServiceProviderABC,
|
||||
ServiceProviderABC.__subclasses__(),
|
||||
ServiceProvider,
|
||||
ServiceProvider.__subclasses__(),
|
||||
ServiceCollection,
|
||||
WrappedLogger,
|
||||
WrappedLogger.__subclasses__(),
|
||||
|
||||
@@ -14,3 +14,4 @@ UuidId = str | UUID
|
||||
SerialId = int
|
||||
|
||||
Id = UuidId | SerialId
|
||||
TNumber = int | float | complex
|
||||
|
||||
57
src/cpl-core/cpl/core/utils/benchmark.py
Normal file
57
src/cpl-core/cpl/core/utils/benchmark.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import time
|
||||
import tracemalloc
|
||||
from typing import List, Callable
|
||||
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class Benchmark:
|
||||
|
||||
@staticmethod
|
||||
def all(label: str, func: Callable, iterations: int = 5):
|
||||
times: List[float] = []
|
||||
mems: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
func()
|
||||
end = time.perf_counter()
|
||||
times.append(end - start)
|
||||
|
||||
for _ in range(iterations):
|
||||
tracemalloc.start()
|
||||
func()
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
mems.append(peak)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
|
||||
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
|
||||
|
||||
@staticmethod
|
||||
def time(label: str, func: Callable, iterations: int = 5):
|
||||
times: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
func()
|
||||
end = time.perf_counter()
|
||||
times.append(end - start)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
|
||||
|
||||
@staticmethod
|
||||
def memory(label: str, func: Callable, iterations: int = 5):
|
||||
mems: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
tracemalloc.start()
|
||||
func()
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
mems.append(peak)
|
||||
|
||||
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
|
||||
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")
|
||||
100
src/cpl-core/cpl/core/utils/cache.py
Normal file
100
src/cpl-core/cpl/core/utils/cache.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
|
||||
|
||||
class Cache(Generic[T]):
|
||||
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
|
||||
self._store = {}
|
||||
self._default_ttl = default_ttl
|
||||
self._lock = threading.Lock()
|
||||
self._cleanup_interval = cleanup_interval
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
self._type = t
|
||||
|
||||
# Start background cleanup thread
|
||||
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def set(self, key: str, value: T, ttl: int = None) -> None:
|
||||
"""Store a value in the cache with optional TTL override."""
|
||||
expire_at = None
|
||||
ttl = ttl if ttl is not None else self._default_ttl
|
||||
if ttl is not None:
|
||||
expire_at = time.time() + ttl
|
||||
|
||||
with self._lock:
|
||||
self._store[key] = (value, expire_at)
|
||||
|
||||
def get(self, key: str) -> T | None:
|
||||
"""Retrieve a value from the cache if not expired."""
|
||||
with self._lock:
|
||||
item = self._store.get(key)
|
||||
if not item:
|
||||
return None
|
||||
value, expire_at = item
|
||||
if expire_at and expire_at < time.time():
|
||||
# Expired -> remove and return None
|
||||
del self._store[key]
|
||||
return None
|
||||
return value
|
||||
|
||||
def get_all(self) -> list[T]:
|
||||
"""Retrieve all non-expired values from the cache."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
valid_items = []
|
||||
expired_keys = []
|
||||
for k, (v, exp) in self._store.items():
|
||||
if exp and exp < now:
|
||||
expired_keys.append(k)
|
||||
else:
|
||||
valid_items.append(v)
|
||||
for k in expired_keys:
|
||||
del self._store[k]
|
||||
return valid_items
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
with self._lock:
|
||||
item = self._store.get(key)
|
||||
if not item:
|
||||
return False
|
||||
_, expire_at = item
|
||||
if expire_at and expire_at < time.time():
|
||||
# Expired -> remove and return False
|
||||
del self._store[key]
|
||||
return False
|
||||
return True
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Remove an item from the cache."""
|
||||
with self._lock:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the entire cache."""
|
||||
with self._lock:
|
||||
self._store.clear()
|
||||
|
||||
def _auto_cleanup(self):
|
||||
"""Background thread to clean expired items."""
|
||||
while not self._stop_event.is_set():
|
||||
self.cleanup()
|
||||
self._stop_event.wait(self._cleanup_interval)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove expired items immediately."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
|
||||
for k in expired_keys:
|
||||
del self._store[k]
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background cleanup thread."""
|
||||
self._stop_event.set()
|
||||
self._thread.join()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user