Compare commits
6 Commits
2025.09.24
...
2025.09.24
| Author | SHA1 | Date | |
|---|---|---|---|
| 6a3fdb3ebd | |||
| b49f663ae0 | |||
| 287f5e3149 | |||
| 4c8cd988cc | |||
| cdb4a0fb34 | |||
| cf8edafd39 |
@@ -6,8 +6,10 @@ from cpl.application import ApplicationBuilder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema import AuthUser, Role
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.utils.cache import Cache
|
||||
from custom.api.src.scoped_service import ScopedService
|
||||
from service import PingService
|
||||
|
||||
|
||||
@@ -23,6 +25,8 @@ def main():
|
||||
builder.services.add_transient(PingService)
|
||||
builder.services.add_module(api)
|
||||
|
||||
builder.services.add_scoped(ScopedService)
|
||||
|
||||
builder.services.add_cache(AuthUser)
|
||||
builder.services.add_cache(Role)
|
||||
|
||||
@@ -40,6 +44,32 @@ def main():
|
||||
user_cache = provider.get_service(Cache[AuthUser])
|
||||
role_cache = provider.get_service(Cache[Role])
|
||||
|
||||
if role_cache == user_cache:
|
||||
raise Exception("Cache service is not working")
|
||||
|
||||
s1 = provider.get_service(ScopedService)
|
||||
s2 = provider.get_service(ScopedService)
|
||||
|
||||
if s1.name == s2.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
with provider.create_scope() as scope:
|
||||
s3 = scope.get_service(ScopedService)
|
||||
s4 = scope.get_service(ScopedService)
|
||||
|
||||
if s3.name != s4.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
if s1.name == s3.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
Console.write_line(
|
||||
s1.name,
|
||||
s2.name,
|
||||
s3.name,
|
||||
s4.name,
|
||||
)
|
||||
|
||||
app.run()
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,17 @@ from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.api import APILogger
|
||||
from cpl.api.router import Router
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProvider
|
||||
from custom.api.src.scoped_service import ScopedService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
# @Router.authorize(permissions=[Permissions.administrator])
|
||||
# @Router.authorize(policies=["test"])
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger):
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
|
||||
logger.info(f"Ping: {ping}")
|
||||
|
||||
Console.write_line(scoped.name)
|
||||
return JSONResponse(ping.ping(r))
|
||||
|
||||
14
example/custom/api/src/scoped_service.py
Normal file
14
example/custom/api/src/scoped_service.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.utils.string import String
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self._name = String.random(8)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -3,7 +3,7 @@ from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.log import LoggerABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from model.city import City
|
||||
from model.city_dao import CityDao
|
||||
from model.user import User
|
||||
@@ -11,7 +11,7 @@ from model.user_dao import UserDao
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
|
||||
self._logger = services.get_service(LoggerABC)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from cpl.application.abc import ApplicationABC
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency.scope import Scope
|
||||
from cpl.dependency import ServiceProvider
|
||||
from di.static_test import StaticTest
|
||||
from di.test_abc import TestABC
|
||||
from di.test_service import TestService
|
||||
@@ -10,33 +9,37 @@ from di.tester import Tester
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
|
||||
def _part_of_scoped(self):
|
||||
ts: TestService = self._services.get_service(TestService)
|
||||
ts.run()
|
||||
|
||||
def configure(self): ...
|
||||
|
||||
def main(self):
|
||||
with self._services.create_scope() as scope:
|
||||
Console.write_line("Scope1")
|
||||
ts: TestService = scope.service_provider.get_service(TestService)
|
||||
ts: TestService = scope.get_service(TestService)
|
||||
ts.run()
|
||||
dit: DITesterService = scope.service_provider.get_service(DITesterService)
|
||||
dit: DITesterService = scope.get_service(DITesterService)
|
||||
dit.run()
|
||||
|
||||
if ts.name != dit.name:
|
||||
raise Exception("DI is broken!")
|
||||
|
||||
with self._services.create_scope() as scope:
|
||||
Console.write_line("Scope2")
|
||||
ts: TestService = scope.service_provider.get_service(TestService)
|
||||
ts: TestService = scope.get_service(TestService)
|
||||
ts.run()
|
||||
dit: DITesterService = scope.service_provider.get_service(DITesterService)
|
||||
dit: DITesterService = scope.get_service(DITesterService)
|
||||
dit.run()
|
||||
|
||||
if ts.name != dit.name:
|
||||
raise Exception("DI is broken!")
|
||||
|
||||
Console.write_line("Global")
|
||||
self._part_of_scoped()
|
||||
StaticTest.test()
|
||||
|
||||
self._services.get_service(Tester)
|
||||
Console.write_line(self._services.get_services(list[TestABC]))
|
||||
Console.write_line(self._services.get_services(TestABC))
|
||||
|
||||
@@ -6,6 +6,10 @@ class DITesterService:
|
||||
def __init__(self, ts: TestService):
|
||||
self._ts = ts
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._ts.name
|
||||
|
||||
def run(self):
|
||||
Console.write_line("DIT: ")
|
||||
self._ts.run()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from cpl.application.abc import StartupABC
|
||||
from cpl.dependency import ServiceProviderABC, ServiceCollection
|
||||
from cpl.dependency import ServiceProvider, ServiceCollection
|
||||
from di.di_tester_service import DITesterService
|
||||
from di.test1_service import Test1Service
|
||||
from di.test2_service import Test2Service
|
||||
@@ -12,9 +12,11 @@ class Startup(StartupABC):
|
||||
def __init__(self):
|
||||
StartupABC.__init__(self)
|
||||
|
||||
def configure_configuration(self): ...
|
||||
@staticmethod
|
||||
def configure_configuration(): ...
|
||||
|
||||
def configure_services(self, services: ServiceCollection) -> ServiceProviderABC:
|
||||
@staticmethod
|
||||
def configure_services(services: ServiceCollection) -> ServiceProvider:
|
||||
services.add_scoped(TestService)
|
||||
services.add_scoped(DITesterService)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from cpl.dependency import ServiceProvider, ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider, ServiceProvider
|
||||
from cpl.dependency.inject import inject
|
||||
from di.test_service import TestService
|
||||
|
||||
|
||||
class StaticTest:
|
||||
@staticmethod
|
||||
@ServiceProvider.inject
|
||||
def test(services: ServiceProviderABC, t1: TestService):
|
||||
@inject
|
||||
def test(services: ServiceProvider, t1: TestService):
|
||||
t1.run()
|
||||
|
||||
@@ -6,7 +6,7 @@ from di.test_abc import TestABC
|
||||
|
||||
class Test1Service(TestABC):
|
||||
def __init__(self):
|
||||
TestABC.__init__(self, String.random_string(string.ascii_lowercase, 8))
|
||||
TestABC.__init__(self, String.random(8))
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
|
||||
@@ -6,7 +6,7 @@ from di.test_abc import TestABC
|
||||
|
||||
class Test2Service(TestABC):
|
||||
def __init__(self):
|
||||
TestABC.__init__(self, String.random_string(string.ascii_lowercase, 8))
|
||||
TestABC.__init__(self, String.random(8))
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import string
|
||||
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.utils.string import String
|
||||
|
||||
@@ -8,5 +6,9 @@ class TestService:
|
||||
def __init__(self):
|
||||
self._name = String.random(8)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
|
||||
@@ -4,19 +4,20 @@ from typing import Optional
|
||||
from cpl.application.abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.log import LoggerABC
|
||||
from cpl.core.pipes import IPAddressPipe
|
||||
from cpl.mail import EMail, EMailClientABC
|
||||
from cpl.query.extension.list import List
|
||||
from cpl.query import List
|
||||
from general.scoped_service import ScopedService
|
||||
from test_service import TestService
|
||||
from test_settings import TestSettings
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
self._logger = self._services.get_service(LoggerABC)
|
||||
self._mailer = self._services.get_service(EMailClientABC)
|
||||
@@ -38,7 +39,7 @@ class Application(ApplicationABC):
|
||||
def main(self):
|
||||
self._logger.debug(f"Host: {Environment.get_host_name()}")
|
||||
self._logger.debug(f"Environment: {Environment.get_environment()}")
|
||||
Console.write_line(List(int, range(0, 10)).select(lambda x: f"x={x}").to_list())
|
||||
Console.write_line(List(range(0, 10)).select(lambda x: f"x={x}").to_list())
|
||||
Console.spinner("Test", self._wait, 2, spinner_foreground_color="red")
|
||||
test: TestService = self._services.get_service(TestService)
|
||||
ip_pipe: IPAddressPipe = self._services.get_service(IPAddressPipe)
|
||||
@@ -48,10 +49,21 @@ class Application(ApplicationABC):
|
||||
Console.write_line(f"DI working: {test == test2 and ip_pipe != ip_pipe2}")
|
||||
Console.write_line(self._services.get_service(LoggerABC))
|
||||
|
||||
scope = self._services.create_scope()
|
||||
Console.write_line("scope", scope)
|
||||
with self._services.create_scope() as s:
|
||||
Console.write_line("with scope", s)
|
||||
root_scoped_service = self._services.get_service(ScopedService)
|
||||
with self._services.create_scope() as scope:
|
||||
s_srvc1 = scope.get_service(ScopedService)
|
||||
s_srvc2 = scope.get_service(ScopedService)
|
||||
|
||||
Console.write_line(root_scoped_service)
|
||||
Console.write_line(s_srvc1)
|
||||
Console.write_line(s_srvc2)
|
||||
if root_scoped_service == s_srvc1 or s_srvc1 != s_srvc2:
|
||||
raise Exception("Root scoped service should not be equal to scoped service")
|
||||
|
||||
root_scoped_service2 = self._services.get_service(ScopedService)
|
||||
Console.write_line(root_scoped_service2)
|
||||
if root_scoped_service == root_scoped_service2:
|
||||
raise Exception("Root scoped service should be equal to root scoped service 2")
|
||||
|
||||
test_settings = Configuration.get(TestSettings)
|
||||
Console.write_line(test_settings.value)
|
||||
|
||||
10
example/custom/general/src/general/scoped_service.py
Normal file
10
example/custom/general/src/general/scoped_service.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self.value = "I am a scoped service"
|
||||
Console.write_line(self.value, self)
|
||||
|
||||
def get_value(self):
|
||||
return self.value
|
||||
@@ -4,6 +4,7 @@ from cpl.core.configuration import Configuration
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.pipes import IPAddressPipe
|
||||
from cpl.dependency import ServiceCollection
|
||||
from general.scoped_service import ScopedService
|
||||
from test_service import TestService
|
||||
|
||||
|
||||
@@ -21,3 +22,4 @@ class Startup(StartupABC):
|
||||
services.add_module(mail)
|
||||
services.add_transient(IPAddressPipe)
|
||||
services.add_singleton(TestService)
|
||||
services.add_scoped(ScopedService)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from cpl.application.abc import ApplicationExtensionABC
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class TestExtension(ApplicationExtensionABC):
|
||||
|
||||
@staticmethod
|
||||
def run(services: ServiceProviderABC):
|
||||
def run(services: ServiceProvider):
|
||||
Console.write_line("Hello World from App Extension")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.core.pipes.ip_address_pipe import IPAddressPipe
|
||||
|
||||
|
||||
class TestService:
|
||||
def __init__(self, provider: ServiceProviderABC):
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
self._provider = provider
|
||||
|
||||
def run(self):
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.utils.benchmark import Benchmark
|
||||
from cpl.query.collection import Collection
|
||||
from cpl.query.enumerable import Enumerable
|
||||
from cpl.query.list import List
|
||||
from cpl.query.immutable_list import ImmutableList
|
||||
from cpl.query.list import List
|
||||
from cpl.query.set import Set
|
||||
|
||||
|
||||
@@ -24,11 +23,23 @@ def _default():
|
||||
s.add(1)
|
||||
Console.write_line(s)
|
||||
|
||||
data = Enumerable(
|
||||
[
|
||||
{"name": "Alice", "age": 30},
|
||||
{"name": "Dave", "age": 35},
|
||||
{"name": "Charlie", "age": 25},
|
||||
{"name": "Bob", "age": 25},
|
||||
]
|
||||
)
|
||||
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
|
||||
|
||||
|
||||
def t_benchmark(data: list):
|
||||
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("Collection", lambda: Collection(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all(
|
||||
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
|
||||
@@ -39,7 +50,7 @@ def t_benchmark(data: list):
|
||||
def main():
|
||||
N = 10_000_000
|
||||
data = list(range(N))
|
||||
t_benchmark(data)
|
||||
#t_benchmark(data)
|
||||
|
||||
Console.write_line()
|
||||
_default()
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from cpl.application import ApplicationABC
|
||||
from cpl.core.configuration import ConfigurationABC
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.translation.translate_pipe import TranslatePipe
|
||||
from cpl.translation.translation_service_abc import TranslationServiceABC
|
||||
from cpl.translation.translation_settings import TranslationSettings
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC):
|
||||
def __init__(self, config: ConfigurationABC, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, config, services)
|
||||
|
||||
self._translate: TranslatePipe = services.get_service(TranslatePipe)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from cpl.application import StartupABC
|
||||
from cpl.core.configuration import ConfigurationABC
|
||||
from cpl.dependency import ServiceProviderABC, ServiceCollection
|
||||
from cpl.dependency import ServiceProvider, ServiceCollection
|
||||
from cpl.core.environment import Environment
|
||||
|
||||
|
||||
@@ -12,6 +12,6 @@ class Startup(StartupABC):
|
||||
configuration.add_json_file("appsettings.json")
|
||||
return configuration
|
||||
|
||||
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProviderABC:
|
||||
def configure_services(self, services: ServiceCollection, environment: Environment) -> ServiceProvider:
|
||||
services.add_translation()
|
||||
return services.build()
|
||||
|
||||
@@ -27,14 +27,14 @@ from cpl.api.settings import ApiSettings
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||
from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
super().__init__(services, [auth, api])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
@@ -44,15 +44,15 @@ class WebApp(ApplicationABC):
|
||||
self._policies = services.get_service(PolicyRegistry)
|
||||
self._routes = services.get_service(RouteRegistry)
|
||||
|
||||
self._middleware: list[Middleware] = [
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
]
|
||||
self._middleware: list[Middleware] = []
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||
Exception: self._handle_exception,
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
self.with_middleware(RequestMiddleware)
|
||||
self.with_middleware(LoggingMiddleware)
|
||||
|
||||
async def _handle_exception(self, request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
self._logger.error(exc)
|
||||
@@ -168,9 +168,9 @@ class WebApp(ApplicationABC):
|
||||
self._check_for_app()
|
||||
|
||||
if isinstance(middleware, Middleware):
|
||||
self._middleware.append(middleware)
|
||||
self._middleware.append(inject(middleware))
|
||||
elif callable(middleware):
|
||||
self._middleware.append(Middleware(middleware))
|
||||
self._middleware.append(Middleware(inject(middleware)))
|
||||
else:
|
||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||
|
||||
@@ -220,7 +220,7 @@ class WebApp(ApplicationABC):
|
||||
self._validate_policies()
|
||||
|
||||
if self._app is None:
|
||||
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
||||
routes = [route.to_starlette(inject) for route in self._routes.all()]
|
||||
|
||||
app = Starlette(
|
||||
routes=routes,
|
||||
|
||||
@@ -9,12 +9,10 @@ from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
|
||||
@@ -9,12 +9,10 @@ from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
|
||||
@@ -6,12 +6,10 @@ from starlette.types import Receive, Scope, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class LoggingMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
|
||||
@@ -9,17 +9,18 @@ from starlette.types import Scope, Receive, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, logger: APILogger):
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
self._logger = logger
|
||||
|
||||
self._ctx_token = None
|
||||
@@ -29,7 +30,8 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
await self._app(scope, receive, send)
|
||||
with self._provider.create_scope():
|
||||
inject(await self._app(scope, receive, send))
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.typing import HTTPMethods
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class Router:
|
||||
@@ -95,9 +96,7 @@ class Router:
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
if not registry:
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
@@ -144,9 +143,8 @@ class Router:
|
||||
"""
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
|
||||
def inner(fn):
|
||||
path = getattr(fn, "_route_path", None)
|
||||
|
||||
@@ -5,7 +5,7 @@ from cpl.application.host import Host
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.log_settings import LogSettings
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
def __not_implemented__(package: str, func: Callable):
|
||||
@@ -16,12 +16,12 @@ class ApplicationABC(ABC):
|
||||
r"""ABC for the Application class
|
||||
|
||||
Parameters:
|
||||
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
|
||||
services: :class:`cpl.dependency.service_provider.ServiceProvider`
|
||||
Contains instances of prepared objects
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
|
||||
def __init__(self, services: ServiceProvider, required_modules: list[str | object] = None):
|
||||
self._services = services
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class ApplicationExtensionABC(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def run(services: ServiceProviderABC): ...
|
||||
def run(services: ServiceProvider): ...
|
||||
|
||||
@@ -7,6 +7,7 @@ from cpl.application.abc.startup_abc import StartupABC
|
||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.dependency.context import get_provider, use_root_provider
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||
@@ -21,6 +22,7 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
self._app = app if app is not None else ApplicationABC
|
||||
|
||||
self._services = ServiceCollection()
|
||||
use_root_provider(self._services.build())
|
||||
|
||||
self._startup: Optional[StartupABC] = None
|
||||
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
|
||||
@@ -34,7 +36,12 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
|
||||
@property
|
||||
def service_provider(self):
|
||||
return self._services.build()
|
||||
provider = get_provider()
|
||||
if provider is None:
|
||||
provider = self._services.build()
|
||||
use_root_provider(provider)
|
||||
|
||||
return provider
|
||||
|
||||
def validate_app_required_modules(self, app: ApplicationABC):
|
||||
for module in app.required_modules:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class KeycloakUser:
|
||||
@@ -32,5 +32,5 @@ class KeycloakUser:
|
||||
def id(self) -> str:
|
||||
from cpl.auth import KeycloakAdmin
|
||||
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user_id(self._username)
|
||||
|
||||
@@ -10,7 +10,8 @@ from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Id, SerialId
|
||||
from cpl.core.utils.credential_manager import CredentialManager
|
||||
from cpl.database.abc.db_model_abc import DbModelABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
_logger = Logger(__name__)
|
||||
|
||||
@@ -47,7 +48,7 @@ class ApiKey(DbModelABC):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
|
||||
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
|
||||
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
|
||||
|
||||
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class AuthUser(DbModelABC):
|
||||
@@ -36,12 +36,12 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("username")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@@ -51,12 +51,12 @@ class AuthUser(DbModelABC):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("email")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@@ -64,26 +64,26 @@ class AuthUser(DbModelABC):
|
||||
async def roles(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
return await auth_user_dao.get_permissions(self.id)
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
return await auth_user_dao.has_permission(self.id, permission)
|
||||
|
||||
async def anonymize(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
|
||||
self._keycloak_id = str(uuid.UUID(int=0))
|
||||
await auth_user_dao.update(self)
|
||||
|
||||
@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
@@ -36,7 +36,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
||||
result = await self._db.select_map(
|
||||
f"""
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class ApiKeyPermission(DbJoinModelABC):
|
||||
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def api_key(self):
|
||||
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
|
||||
|
||||
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
|
||||
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
|
||||
return await api_key_dao.get_by_id(self._api_key_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -6,7 +6,7 @@ from async_property import async_property
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class Role(DbModelABC):
|
||||
@@ -44,22 +44,22 @@ class Role(DbModelABC):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def users(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
|
||||
p = await permission_dao.get_by_name(permission.value)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class RolePermission(DbModelABC):
|
||||
@@ -31,7 +31,7 @@ class RolePermission(DbModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class RolePermission(DbModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class RoleUser(DbJoinModelABC):
|
||||
@@ -31,7 +31,7 @@ class RoleUser(DbJoinModelABC):
|
||||
async def user(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao)
|
||||
return await auth_user_dao.get_by_id(self._user_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@@ -2,15 +2,15 @@ from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
||||
|
||||
|
||||
def set_user(user: Optional[AuthUser]):
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
|
||||
logger = ServiceProviderABC.get_global_service(LoggerABC)
|
||||
logger = get_provider().get_service(LoggerABC)
|
||||
logger.trace("Setting user context", user.id)
|
||||
_user_context.set(user)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from starlette.requests import Request
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source, Messages
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class StructuredLogger(Logger):
|
||||
@@ -99,10 +100,9 @@ class StructuredLogger(Logger):
|
||||
if user is None:
|
||||
return
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
|
||||
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
kc_user = keycloak.get_user(user.keycloak_id)
|
||||
message["user"] = {
|
||||
"id": str(user.id),
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Type
|
||||
|
||||
from cpl.core.log import LoggerABC, LogLevel
|
||||
from cpl.core.typing import Messages
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class WrappedLogger(LoggerABC):
|
||||
@@ -17,13 +18,13 @@ class WrappedLogger(LoggerABC):
|
||||
|
||||
self._set_logger()
|
||||
|
||||
@ServiceProviderABC.inject
|
||||
def _set_logger(self, services: ServiceProviderABC):
|
||||
@inject
|
||||
def _set_logger(self, services: ServiceProvider):
|
||||
from cpl.core.log import Logger
|
||||
|
||||
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
|
||||
if t_logger is None:
|
||||
raise Exception("No LoggerABC service registered in ServiceProviderABC")
|
||||
raise Exception("No LoggerABC service registered in ServiceProvider")
|
||||
|
||||
self._logger = t_logger(self._source, self._file_prefix)
|
||||
|
||||
@@ -42,8 +43,8 @@ class WrappedLogger(LoggerABC):
|
||||
from cpl.dependency import ServiceCollection
|
||||
|
||||
ignore_classes = [
|
||||
ServiceProviderABC,
|
||||
ServiceProviderABC.__subclasses__(),
|
||||
ServiceProvider,
|
||||
ServiceProvider.__subclasses__(),
|
||||
ServiceCollection,
|
||||
WrappedLogger,
|
||||
WrappedLogger.__subclasses__(),
|
||||
|
||||
@@ -114,12 +114,15 @@ class String:
|
||||
|
||||
characters = []
|
||||
if letters:
|
||||
characters.append(string.ascii_letters)
|
||||
characters.extend(string.ascii_letters)
|
||||
|
||||
if digits:
|
||||
characters.append(string.digits)
|
||||
characters.extend(string.digits)
|
||||
|
||||
if special_characters:
|
||||
characters.append(string.punctuation)
|
||||
characters.extend(string.punctuation)
|
||||
|
||||
return "".join(random.choice(characters) for _ in range(length)) if characters else ""
|
||||
x = "".join(random.choice(list(characters)) for _ in range(length)) if characters else ""
|
||||
if len(x) != length:
|
||||
raise Exception("No characters selected to generate random string")
|
||||
return x
|
||||
|
||||
@@ -9,21 +9,19 @@ from cpl.core.utils.get_value import get_value
|
||||
from cpl.core.utils.string import String
|
||||
from cpl.database.abc.db_context_abc import DBContextABC
|
||||
from cpl.database.const import DATETIME_FORMAT
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.postgres.sql_select_builder import SQLSelectBuilder
|
||||
from cpl.database.typing import T_DBM, Attribute, AttributeFilters, AttributeSorts
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, model_type: Type[T_DBM], table_name: str):
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
self._db = ServiceProviderABC.get_global_service(DBContextABC)
|
||||
|
||||
self._logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
self._db = get_provider().get_service(DBContextABC)
|
||||
self._logger = get_provider().get_service(DBLogger)
|
||||
self._model_type = model_type
|
||||
self._table_name = table_name
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from mysql.connector.aio import MySQLConnectionPool
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class MySQLPool:
|
||||
@@ -35,7 +35,7 @@ class MySQLPool:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
except Exception as e:
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Error connecting to the database: {e}")
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
@@ -7,7 +7,7 @@ from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.model import DatabaseSettings
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class PostgresPool:
|
||||
@@ -37,7 +37,7 @@ class PostgresPool:
|
||||
await pool.check_connection(con)
|
||||
except PoolTimeout as e:
|
||||
await pool.close()
|
||||
logger = ServiceProviderABC.get_global_service(DBLogger)
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Failed to connect to the database", e)
|
||||
self._pool = pool
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class SeederService:
|
||||
|
||||
def __init__(self, provider: ServiceProviderABC):
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
self._provider = provider
|
||||
self._logger = provider.get_service(DBLogger)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .scope import Scope
|
||||
from .scope_abc import ScopeABC
|
||||
from .context import get_provider, use_provider
|
||||
from .inject import inject
|
||||
from .service_collection import ServiceCollection
|
||||
from .service_descriptor import ServiceDescriptor
|
||||
from .service_lifetime_enum import ServiceLifetimeEnum
|
||||
from .service_lifetime import ServiceLifetimeEnum
|
||||
from .service_provider import ServiceProvider
|
||||
from .service_provider import ServiceProvider
|
||||
from .service_provider_abc import ServiceProviderABC
|
||||
|
||||
21
src/cpl-dependency/cpl/dependency/context.py
Normal file
21
src/cpl-dependency/cpl/dependency/context.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import contextvars
|
||||
from contextlib import contextmanager
|
||||
|
||||
_current_provider = contextvars.ContextVar("current_provider", default=None)
|
||||
|
||||
|
||||
def use_root_provider(provider):
|
||||
_current_provider.set(provider)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_provider(provider):
|
||||
token = _current_provider.set(provider)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_current_provider.reset(token)
|
||||
|
||||
|
||||
def get_provider():
|
||||
return _current_provider.get()
|
||||
42
src/cpl-dependency/cpl/dependency/inject.py
Normal file
42
src/cpl-dependency/cpl/dependency/inject.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import functools
|
||||
from asyncio import iscoroutinefunction
|
||||
from inspect import signature
|
||||
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
def inject(f=None):
|
||||
if f is None:
|
||||
return functools.partial(inject)
|
||||
|
||||
if iscoroutinefunction(f):
|
||||
|
||||
@functools.wraps(f)
|
||||
async def async_inner(*args, **kwargs):
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
provider: ServiceProvider | None = get_provider()
|
||||
if provider is None:
|
||||
raise ValueError(
|
||||
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
|
||||
)
|
||||
|
||||
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
|
||||
return await f(*args, *injection, **kwargs)
|
||||
|
||||
return async_inner
|
||||
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
provider: ServiceProvider | None = get_provider()
|
||||
if provider is None:
|
||||
raise ValueError(
|
||||
"No provider in current context. Use 'with use_provider(provider):' to set the provider in the current context."
|
||||
)
|
||||
|
||||
injection = [x for x in provider._build_by_signature(signature(f)) if x is not None]
|
||||
return f(*args, *injection, **kwargs)
|
||||
|
||||
return inner
|
||||
@@ -1,22 +0,0 @@
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class Scope(ScopeABC):
|
||||
def __init__(self, service_provider: ServiceProviderABC):
|
||||
self._service_provider = service_provider
|
||||
self._service_provider.set_scope(self)
|
||||
ScopeABC.__init__(self)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.dispose()
|
||||
|
||||
@property
|
||||
def service_provider(self) -> ServiceProviderABC:
|
||||
return self._service_provider
|
||||
|
||||
def dispose(self):
|
||||
self._service_provider = None
|
||||
@@ -1,20 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ScopeABC(ABC):
|
||||
r"""ABC for the class :class:`cpl.dependency.scope.Scope`"""
|
||||
|
||||
def __init__(self): ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def service_provider(self):
|
||||
r"""Returns to service provider of scope
|
||||
|
||||
Returns:
|
||||
Object of type :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dispose(self):
|
||||
r"""Sets service_provider to None"""
|
||||
@@ -1,18 +0,0 @@
|
||||
from cpl.dependency.scope import Scope
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class ScopeBuilder:
|
||||
r"""Class to build :class:`cpl.dependency.scope.Scope`"""
|
||||
|
||||
def __init__(self, service_provider: ServiceProviderABC) -> None:
|
||||
self._service_provider = service_provider
|
||||
|
||||
def build(self) -> ScopeABC:
|
||||
r"""Returns scope
|
||||
|
||||
Returns:
|
||||
Object of type :class:`cpl.dependency.scope.Scope`
|
||||
"""
|
||||
return Scope(self._service_provider)
|
||||
@@ -4,9 +4,8 @@ from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.core.typing import T, Service
|
||||
from cpl.core.utils.cache import Cache
|
||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
class ServiceCollection:
|
||||
@@ -62,9 +61,8 @@ class ServiceCollection:
|
||||
self._add_descriptor_by_lifetime(service_type, ServiceLifetimeEnum.transient, service)
|
||||
return self
|
||||
|
||||
def build(self) -> ServiceProviderABC:
|
||||
def build(self) -> ServiceProvider:
|
||||
sp = ServiceProvider(self._service_descriptors)
|
||||
ServiceProviderABC.set_global_provider(sp)
|
||||
return sp
|
||||
|
||||
def add_module(self, module: str | object) -> Self:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Union, Optional
|
||||
|
||||
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
|
||||
|
||||
class ServiceDescriptor:
|
||||
|
||||
7
src/cpl-dependency/cpl/dependency/service_lifetime.py
Normal file
7
src/cpl-dependency/cpl/dependency/service_lifetime.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class ServiceLifetimeEnum(Enum):
|
||||
singleton = auto()
|
||||
scoped = auto()
|
||||
transient = auto()
|
||||
@@ -1,7 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ServiceLifetimeEnum(Enum):
|
||||
singleton = 0
|
||||
scoped = 1
|
||||
transient = 2
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
from inspect import signature, Parameter, Signature
|
||||
from typing import Optional, Type
|
||||
|
||||
@@ -7,34 +8,15 @@ from cpl.core.configuration import Configuration
|
||||
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.typing import T, R, Source
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
from cpl.dependency.scope_builder import ScopeBuilder
|
||||
from cpl.dependency import use_provider
|
||||
from cpl.dependency.service_descriptor import ServiceDescriptor
|
||||
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
from cpl.dependency.service_lifetime import ServiceLifetimeEnum
|
||||
|
||||
|
||||
class ServiceProvider(ServiceProviderABC):
|
||||
r"""Provider for the services
|
||||
|
||||
Parameter
|
||||
---------
|
||||
service_descriptors: list[:class:`cpl.dependency.service_descriptor.ServiceDescriptor`]
|
||||
Descriptor of the service
|
||||
config: :class:`cpl.core.configuration.configuration_abc.ConfigurationABC`
|
||||
CPL Configuration
|
||||
db_context: Optional[:class:`cpl.database.context.database_context_abc.DatabaseContextABC`]
|
||||
Database representation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_descriptors: list[ServiceDescriptor],
|
||||
):
|
||||
ServiceProviderABC.__init__(self)
|
||||
|
||||
class ServiceProvider:
|
||||
def __init__(self, service_descriptors: list[ServiceDescriptor], is_scope: bool = False):
|
||||
self._service_descriptors: list[ServiceDescriptor] = service_descriptors
|
||||
self._scope: Optional[ScopeABC] = None
|
||||
self._is_scope = is_scope
|
||||
|
||||
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
|
||||
origin_type = typing.get_origin(service_type) or service_type
|
||||
@@ -67,7 +49,7 @@ class ServiceProvider(ServiceProviderABC):
|
||||
return descriptor.implementation
|
||||
|
||||
implementation = self._build_service(descriptor.service_type, origin_service_type=origin_service_type)
|
||||
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
|
||||
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
|
||||
descriptor.implementation = implementation
|
||||
|
||||
return implementation
|
||||
@@ -85,7 +67,7 @@ class ServiceProvider(ServiceProviderABC):
|
||||
implementation = self._build_service(
|
||||
descriptor.service_type, origin_service_type=service_type, **kwargs
|
||||
)
|
||||
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
|
||||
if descriptor.lifetime in (ServiceLifetimeEnum.singleton, ServiceLifetimeEnum.scoped):
|
||||
descriptor.implementation = implementation
|
||||
|
||||
implementations.append(implementation)
|
||||
@@ -103,7 +85,7 @@ class ServiceProvider(ServiceProviderABC):
|
||||
elif parameter.annotation == Source:
|
||||
params.append(origin_service_type.__name__)
|
||||
|
||||
elif issubclass(parameter.annotation, ServiceProviderABC):
|
||||
elif issubclass(parameter.annotation, ServiceProvider):
|
||||
params.append(self)
|
||||
|
||||
elif issubclass(parameter.annotation, Environment):
|
||||
@@ -131,32 +113,27 @@ class ServiceProvider(ServiceProviderABC):
|
||||
service_type = type(descriptor.implementation)
|
||||
else:
|
||||
service_type = descriptor.service_type
|
||||
|
||||
break
|
||||
|
||||
sig = signature(service_type.__init__)
|
||||
params = self._build_by_signature(sig, origin_service_type)
|
||||
|
||||
return service_type(*params, *args, **kwargs)
|
||||
|
||||
def set_scope(self, scope: ScopeABC):
|
||||
self._scope = scope
|
||||
|
||||
def create_scope(self) -> ScopeABC:
|
||||
descriptors = []
|
||||
|
||||
for descriptor in self._service_descriptors:
|
||||
if descriptor.lifetime == ServiceLifetimeEnum.singleton:
|
||||
descriptors.append(descriptor)
|
||||
@contextmanager
|
||||
def create_scope(self):
|
||||
scoped_descriptors = []
|
||||
for d in self._service_descriptors:
|
||||
if d.lifetime == ServiceLifetimeEnum.singleton:
|
||||
scoped_descriptors.append(d)
|
||||
else:
|
||||
descriptors.append(copy.deepcopy(descriptor))
|
||||
scoped_descriptors.append(copy.deepcopy(d))
|
||||
|
||||
sb = ScopeBuilder(ServiceProvider(descriptors))
|
||||
return sb.build()
|
||||
scoped_provider = ServiceProvider(scoped_descriptors, is_scope=True)
|
||||
with use_provider(scoped_provider):
|
||||
yield scoped_provider
|
||||
|
||||
def get_service(self, service_type: T, *args, **kwargs) -> Optional[R]:
|
||||
result = self._find_service(service_type)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
@@ -164,11 +141,10 @@ class ServiceProvider(ServiceProviderABC):
|
||||
return result.implementation
|
||||
|
||||
implementation = self._build_service(service_type, *args, **kwargs)
|
||||
if (
|
||||
result.lifetime == ServiceLifetimeEnum.singleton
|
||||
or result.lifetime == ServiceLifetimeEnum.scoped
|
||||
and self._scope is not None
|
||||
):
|
||||
|
||||
if result.lifetime == ServiceLifetimeEnum.singleton:
|
||||
result.implementation = implementation
|
||||
elif result.lifetime == ServiceLifetimeEnum.scoped and self._is_scope:
|
||||
result.implementation = implementation
|
||||
|
||||
return implementation
|
||||
@@ -181,12 +157,9 @@ class ServiceProvider(ServiceProviderABC):
|
||||
|
||||
def get_services(self, service_type: T, *args, **kwargs) -> list[Optional[R]]:
|
||||
implementations = []
|
||||
|
||||
if typing.get_origin(service_type) == list:
|
||||
raise Exception(f"Invalid type {service_type}! Expected single type not list of type")
|
||||
|
||||
implementations.extend(self._get_services(service_type))
|
||||
|
||||
return implementations
|
||||
|
||||
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
import functools
|
||||
from abc import abstractmethod, ABC
|
||||
from inspect import Signature, signature, iscoroutinefunction
|
||||
from typing import Optional, Type
|
||||
|
||||
from cpl.core.typing import T, R
|
||||
from cpl.dependency.scope_abc import ScopeABC
|
||||
|
||||
|
||||
class ServiceProviderABC(ABC):
|
||||
r"""ABC for the class :class:`cpl.dependency.service_provider.ServiceProvider`"""
|
||||
|
||||
_provider: Optional["ServiceProviderABC"] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self): ...
|
||||
|
||||
@classmethod
|
||||
def set_global_provider(cls, provider: "ServiceProviderABC"):
|
||||
cls._provider = provider
|
||||
|
||||
@classmethod
|
||||
def get_global_provider(cls) -> Optional["ServiceProviderABC"]:
|
||||
return cls._provider
|
||||
|
||||
@classmethod
|
||||
def get_global_service(cls, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||
if cls._provider is None:
|
||||
return None
|
||||
return cls._provider.get_service(instance_type, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_global_services(cls, instance_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||
if cls._provider is None:
|
||||
return []
|
||||
return cls._provider.get_services(instance_type, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _build_service(self, service_type: type, *args, **kwargs) -> object:
|
||||
r"""Creates instance of given type
|
||||
|
||||
Parameter
|
||||
---------
|
||||
instance_type: :class:`type`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of the given type
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_scope(self, scope: ScopeABC):
|
||||
r"""Sets the scope of service provider
|
||||
|
||||
Parameter
|
||||
---------
|
||||
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
|
||||
Service scope
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_scope(self) -> ScopeABC:
|
||||
r"""Creates a service scope
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type :class:`cpl.dependency.scope_abc.ScopeABC`
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_service(self, instance_type: Type[T], *args, **kwargs) -> Optional[T]:
|
||||
r"""Returns instance of given type
|
||||
|
||||
Parameter
|
||||
---------
|
||||
instance_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type Optional[:class:`cpl.core.type.T`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_service_type(self, instance_type: Type[T]) -> Optional[Type[T]]:
|
||||
r"""Returns the registered service type for loggers
|
||||
|
||||
Parameter
|
||||
---------
|
||||
instance_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type Optional[:class:`type`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_services(self, service_type: Type[T], *args, **kwargs) -> list[Optional[T]]:
|
||||
r"""Returns instance of given type
|
||||
|
||||
Parameter
|
||||
---------
|
||||
service_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type list[Optional[:class:`cpl.core.type.T`]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_service_types(self, service_type: Type[T]) -> list[Type[T]]:
|
||||
r"""Returns all registered service types
|
||||
|
||||
Parameter
|
||||
---------
|
||||
service_type: :class:`cpl.core.type.T`
|
||||
The type of the searched instance
|
||||
|
||||
Returns
|
||||
-------
|
||||
Object of type list[:class:`type`]
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def inject(cls, f=None):
|
||||
r"""Decorator to allow injection into static and class methods
|
||||
|
||||
Parameter
|
||||
---------
|
||||
f: Callable
|
||||
|
||||
Returns
|
||||
-------
|
||||
function
|
||||
"""
|
||||
if f is None:
|
||||
return functools.partial(cls.inject)
|
||||
|
||||
if iscoroutinefunction(f):
|
||||
|
||||
@functools.wraps(f)
|
||||
async def async_inner(*args, **kwargs):
|
||||
if cls._provider is None:
|
||||
raise Exception(f"{cls.__name__} not build!")
|
||||
|
||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||
return await f(*args, *injection, **kwargs)
|
||||
|
||||
return async_inner
|
||||
|
||||
@functools.wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
if cls._provider is None:
|
||||
raise Exception(f"{cls.__name__} not build!")
|
||||
|
||||
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
|
||||
return f(*args, *injection, **kwargs)
|
||||
|
||||
return inner
|
||||
@@ -1 +1,7 @@
|
||||
|
||||
from .array import Array
|
||||
from .enumerable import Enumerable
|
||||
from .immutable_list import ImmutableList
|
||||
from .immutable_set import ImmutableSet
|
||||
from .list import List
|
||||
from .ordered_enumerable import OrderedEnumerable
|
||||
from .set import Set
|
||||
|
||||
44
src/cpl-query/cpl/query/array.py
Normal file
44
src/cpl-query/cpl/query/array.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Generic, Iterable, Optional
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.query.list import List
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
|
||||
class Array(Generic[T], List[T]):
|
||||
def __init__(self, length: int, source: Optional[Iterable[T]] = None):
|
||||
List.__init__(self, source)
|
||||
self._length = length
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self._source)
|
||||
|
||||
def add(self, item: T) -> None:
|
||||
if self._length == self.length:
|
||||
raise IndexError("Array is full")
|
||||
self._source.append(item)
|
||||
|
||||
def extend(self, items: Iterable[T]) -> None:
|
||||
if self._length == self.length:
|
||||
raise IndexError("Array is full")
|
||||
self._source.extend(items)
|
||||
|
||||
def insert(self, index: int, item: T) -> None:
|
||||
if index < 0 or index > self.length:
|
||||
raise IndexError("Index out of range")
|
||||
self._source.insert(index, item)
|
||||
|
||||
def remove(self, item: T) -> None:
|
||||
self._source.remove(item)
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
return self._source.pop(index)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._source.clear()
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._source)
|
||||
@@ -1,173 +0,0 @@
|
||||
from itertools import islice, groupby
|
||||
from typing import Generic, Callable, Iterable, Iterator, Dict, Tuple, Optional
|
||||
|
||||
from cpl.core.typing import T, R
|
||||
from cpl.query.list import List
|
||||
from cpl.query.typing import Predicate, K, Selector
|
||||
|
||||
|
||||
class Collection(Generic[T]):
|
||||
def __init__(self, source: Iterable[T]):
|
||||
self._source = source
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._source)
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return sum(1 for _ in self._source)
|
||||
|
||||
def where(self, f: Predicate) -> "Collection[T]":
|
||||
return Collection([x for x in self._source if f(x)])
|
||||
|
||||
def select(self, f: Selector) -> "Collection[R]":
|
||||
return Collection([f(x) for x in self._source])
|
||||
|
||||
def select_many(self, f: Callable[[T], Iterable[R]]) -> "Collection[R]":
|
||||
return Collection([y for x in self._source for y in f(x)])
|
||||
|
||||
def take(self, count: int) -> "Collection[T]":
|
||||
return Collection(islice(self._source, count))
|
||||
|
||||
def skip(self, count: int) -> "Collection[T]":
|
||||
return Collection(islice(self._source, count, None))
|
||||
|
||||
def take_while(self, f: Predicate) -> "Collection[T]":
|
||||
result = []
|
||||
for x in self._source:
|
||||
if f(x):
|
||||
result.append(x)
|
||||
else:
|
||||
break
|
||||
return Collection(result)
|
||||
|
||||
def skip_while(self, f: Predicate) -> "Collection[T]":
|
||||
it = iter(self._source)
|
||||
for x in it:
|
||||
if not f(x):
|
||||
return Collection([x] + list(it))
|
||||
return Collection([])
|
||||
|
||||
def distinct(self) -> "Collection[T]":
|
||||
seen = set()
|
||||
return Collection([x for x in self._source if not (x in seen or seen.add(x))])
|
||||
|
||||
def union(self, other: Iterable[T]) -> "Collection[T]":
|
||||
return self.concat(other).distinct()
|
||||
|
||||
def intersect(self, other: Iterable[T]) -> "Collection[T]":
|
||||
other_set = set(other)
|
||||
return Collection([x for x in self._source if x in other_set])
|
||||
|
||||
def except_(self, other: Iterable[T]) -> "Collection[T]":
|
||||
other_set = set(other)
|
||||
return Collection([x for x in self._source if x not in other_set])
|
||||
|
||||
def concat(self, other: Iterable[T]) -> "Collection[T]":
|
||||
return Collection(self._source) + list(other)
|
||||
|
||||
def count(self) -> int:
|
||||
return len(list(self._source))
|
||||
|
||||
def sum(self, f: Optional[Selector] = None) -> R:
|
||||
return sum([f(x) for x in self._source]) if f else sum(self._source) # type: ignore
|
||||
|
||||
def min(self, f: Optional[Selector] = None) -> R:
|
||||
return min([f(x) for x in self._source]) if f else min(self._source) # type: ignore
|
||||
|
||||
def max(self, f: Optional[Selector] = None) -> R:
|
||||
return max([f(x) for x in self._source]) if f else max(self._source) # type: ignore
|
||||
|
||||
def average(self, f: Optional[Callable[[T], float]] = None) -> float:
|
||||
values = [f(x) for x in self._source] if f else list(self._source)
|
||||
return sum(values) / len(values) if values else 0.0
|
||||
|
||||
def aggregate(self, func: Callable[[R, T], R], seed: Optional[R] = None) -> R:
|
||||
it = iter(self._source)
|
||||
if seed is None:
|
||||
acc = next(it) # type: ignore
|
||||
else:
|
||||
acc = seed
|
||||
for x in it:
|
||||
acc = func(acc, x)
|
||||
return acc
|
||||
|
||||
def any(self, f: Optional[Predicate] = None) -> bool:
|
||||
return any(f(x) if f else x for x in self._source)
|
||||
|
||||
def all(self, f: Predicate) -> bool:
|
||||
return all(f(x) for x in self._source)
|
||||
|
||||
def contains(self, value: T) -> bool:
|
||||
return value in self._source
|
||||
|
||||
def sequence_equal(self, other: Iterable[T]) -> bool:
|
||||
return list(self._source) == list(other)
|
||||
|
||||
def group_by(self, key_f: Callable[[T], K]) -> "Collection[Tuple[K, List[T]]]":
|
||||
sorted_data = sorted(self._source, key=key_f)
|
||||
return Collection([(key, list(group)) for key, group in groupby(sorted_data, key=key_f)])
|
||||
|
||||
def join(
|
||||
self, inner: Iterable[R], outer_key: Callable[[T], K], inner_key: Callable[[R], K], result: Callable[[T, R], R]
|
||||
) -> "Collection[R]":
|
||||
lookup: Dict[K, List[R]] = {}
|
||||
for i in inner:
|
||||
k = inner_key(i)
|
||||
lookup.setdefault(k, []).append(i)
|
||||
return Collection([result(o, i) for o in self._source for i in lookup.get(outer_key(o), [])])
|
||||
|
||||
def first(self, f: Optional[Predicate] = None) -> T:
|
||||
if f:
|
||||
for x in self._source:
|
||||
if f(x):
|
||||
return x
|
||||
raise ValueError("No matching element")
|
||||
return next(iter(self._source))
|
||||
|
||||
def first_or_default(self, default: Optional[T] = None) -> Optional[T]:
|
||||
return next(iter(self._source), default)
|
||||
|
||||
def last(self) -> T:
|
||||
return list(self._source)[-1]
|
||||
|
||||
def single(self, f: Optional[Predicate] = None) -> T:
|
||||
items = [x for x in self._source if f(x)] if f else list(self._source)
|
||||
if len(items) != 1:
|
||||
raise ValueError("Sequence does not contain exactly one element")
|
||||
return items[0]
|
||||
|
||||
def to_list(self) -> List[T]:
|
||||
return List(self._source)
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._source)
|
||||
|
||||
def to_dict(self, key_f: Callable[[T], K], value_f: Selector) -> Dict[K, R]:
|
||||
return {key_f(x): value_f(x) for x in self._source}
|
||||
|
||||
def cast(self, t: Selector) -> "Collection[R]":
|
||||
return Collection([t(x) for x in self._source])
|
||||
|
||||
def of_type(self, t: type) -> "Collection[T]":
|
||||
return Collection([x for x in self._source if isinstance(x, t)])
|
||||
|
||||
def reverse(self) -> "Collection[T]":
|
||||
return Collection(reversed(list(self._source)))
|
||||
|
||||
def zip(self, other: Iterable[R]) -> "Collection[Tuple[T, R]]":
|
||||
return Collection(zip(self._source, other))
|
||||
|
||||
@staticmethod
|
||||
def range(start: int, count: int) -> "Collection[int]":
|
||||
return Collection(range(start, start + count))
|
||||
|
||||
@staticmethod
|
||||
def repeat(value: T, count: int) -> "Collection[T]":
|
||||
return Collection([value for _ in range(count)])
|
||||
|
||||
@staticmethod
|
||||
def empty() -> "Collection[T]":
|
||||
return Collection([])
|
||||
@@ -167,17 +167,13 @@ class Enumerable(Generic[T]):
|
||||
|
||||
def to_list(self) -> "List[T]":
|
||||
from cpl.query.list import List
|
||||
return List(self._source)
|
||||
|
||||
def to_collection(self) -> "Collection[T]":
|
||||
from cpl.query.collection import Collection
|
||||
|
||||
return Collection(self._source)
|
||||
return List(self)
|
||||
|
||||
def to_set(self) -> "Set[T]":
|
||||
from cpl.query.set import Set
|
||||
|
||||
return Set(self._source)
|
||||
return Set(self)
|
||||
|
||||
def to_dict(self, key_f: Callable[[T], K], value_f: Selector) -> Dict[K, R]:
|
||||
return {key_f(x): value_f(x) for x in self._source}
|
||||
@@ -194,6 +190,16 @@ class Enumerable(Generic[T]):
|
||||
def zip(self, other: Iterable[R]) -> "Enumerable[Tuple[T, R]]":
|
||||
return Enumerable(zip(self._source, other))
|
||||
|
||||
def order_by(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
|
||||
from cpl.query.ordered_enumerable import OrderedEnumerable
|
||||
|
||||
return OrderedEnumerable(self._source, [(key_selector, False)])
|
||||
|
||||
def order_by_descending(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
|
||||
from cpl.query.ordered_enumerable import OrderedEnumerable
|
||||
|
||||
return OrderedEnumerable(self._source, [(key_selector, True)])
|
||||
|
||||
@staticmethod
|
||||
def range(start: int, count: int) -> "Enumerable[int]":
|
||||
return Enumerable(range(start, start + count))
|
||||
|
||||
@@ -6,40 +6,60 @@ from cpl.query.enumerable import Enumerable
|
||||
|
||||
class ImmutableList(Generic[T], Enumerable[T]):
|
||||
def __init__(self, source: Optional[Iterable[T]] = None):
|
||||
Enumerable.__init__(self, [])
|
||||
if source is None:
|
||||
source = []
|
||||
elif not isinstance(source, list):
|
||||
source = list(source)
|
||||
|
||||
Enumerable.__init__(self, source)
|
||||
self.__source = source
|
||||
|
||||
@property
|
||||
def _items(self) -> list[T]:
|
||||
return list(self._source)
|
||||
def _source(self) -> list[T]:
|
||||
return self.__source
|
||||
|
||||
@_source.setter
|
||||
def _source(self, value: list[T]) -> None:
|
||||
self.__source = value
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._items)
|
||||
return iter(self._source)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._items)
|
||||
return len(self._source)
|
||||
|
||||
def __getitem__(self, index: int) -> T:
|
||||
return self._items[index]
|
||||
return self._source[index]
|
||||
|
||||
def __contains__(self, item: T) -> bool:
|
||||
return item in self._items
|
||||
return item in self._source
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"List({self._items!r})"
|
||||
return f"List({self._source!r})"
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self._items)
|
||||
return len(self._source)
|
||||
|
||||
def add(self, item: T) -> None:
|
||||
self._source.append(item)
|
||||
|
||||
def extend(self, items: Iterable[T]) -> None:
|
||||
self._source.extend(items)
|
||||
|
||||
def insert(self, index: int, item: T) -> None:
|
||||
self._source.insert(index, item)
|
||||
|
||||
def remove(self, item: T) -> None:
|
||||
self._source.remove(item)
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
return self._source.pop(index)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._source.clear()
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._items)
|
||||
|
||||
def to_collection(self) -> "Collection[T]":
|
||||
from cpl.query.collection import Collection
|
||||
|
||||
return Collection(self._items)
|
||||
return Enumerable(self._source)
|
||||
|
||||
@@ -6,11 +6,13 @@ from cpl.query.enumerable import Enumerable
|
||||
|
||||
class ImmutableSet(Generic[T], Enumerable[T]):
|
||||
def __init__(self, source: Optional[Iterable[T]] = None):
|
||||
Enumerable.__init__(self, [])
|
||||
if source is None:
|
||||
source = set()
|
||||
elif not isinstance(source, set):
|
||||
source = set(source)
|
||||
|
||||
self.__source = source
|
||||
Enumerable.__init__(self, [])
|
||||
|
||||
@property
|
||||
def _source(self) -> set[T]:
|
||||
@@ -41,4 +43,5 @@ class ImmutableSet(Generic[T], Enumerable[T]):
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
return Enumerable(self._source)
|
||||
|
||||
return Enumerable(self._source)
|
||||
|
||||
@@ -1,66 +1,36 @@
|
||||
from typing import Generic, Iterable, Iterator, Optional
|
||||
from typing import Generic, Iterable, Optional
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.query.immutable_list import ImmutableList
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
|
||||
class List(Generic[T], Enumerable[T]):
|
||||
class List(Generic[T], ImmutableList[T]):
|
||||
def __init__(self, source: Optional[Iterable[T]] = None):
|
||||
if source is None:
|
||||
source = []
|
||||
|
||||
Enumerable.__init__(self, source)
|
||||
|
||||
@property
|
||||
def _items(self) -> list[T]:
|
||||
return list(self._source)
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._items)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._items)
|
||||
|
||||
def __getitem__(self, index: int) -> T:
|
||||
return self._items[index]
|
||||
ImmutableList.__init__(self, source)
|
||||
|
||||
def __setitem__(self, index: int, value: T) -> None:
|
||||
self._items[index] = value
|
||||
|
||||
def __contains__(self, item: T) -> bool:
|
||||
return item in self._items
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"List({self._items!r})"
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self._items)
|
||||
self._source[index] = value
|
||||
|
||||
def add(self, item: T) -> None:
|
||||
self._items.append(item)
|
||||
self._source.append(item)
|
||||
|
||||
def extend(self, items: Iterable[T]) -> None:
|
||||
self._items.extend(items)
|
||||
self._source.extend(items)
|
||||
|
||||
def insert(self, index: int, item: T) -> None:
|
||||
self._items.insert(index, item)
|
||||
self._source.insert(index, item)
|
||||
|
||||
def remove(self, item: T) -> None:
|
||||
self._items.remove(item)
|
||||
self._source.remove(item)
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
return self._items.pop(index)
|
||||
return self._source.pop(index)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._items.clear()
|
||||
self._source.clear()
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
return Enumerable(self._items)
|
||||
|
||||
def to_collection(self) -> "Collection[T]":
|
||||
from cpl.query.collection import Collection
|
||||
|
||||
return Collection(self._items)
|
||||
return Enumerable(self._source)
|
||||
|
||||
40
src/cpl-query/cpl/query/ordered_enumerable.py
Normal file
40
src/cpl-query/cpl/query/ordered_enumerable.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Callable, List, Generic, Iterator
|
||||
from cpl.core.typing import T
|
||||
from cpl.query.enumerable import Enumerable
|
||||
from cpl.query.typing import K
|
||||
|
||||
|
||||
class OrderedEnumerable(Enumerable[T]):
|
||||
def __init__(self, source, key_selectors: List[tuple[Callable[[T], K], bool]]):
|
||||
super().__init__(source)
|
||||
self._key_selectors = key_selectors
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
def composite_key(x):
|
||||
keys = []
|
||||
for selector, descending in self._key_selectors:
|
||||
k = selector(x)
|
||||
keys.append((k, not descending))
|
||||
return tuple(k if asc else _DescendingWrapper(k) for k, asc in keys)
|
||||
|
||||
return iter(sorted(self._source, key=composite_key))
|
||||
|
||||
def then_by(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
|
||||
return OrderedEnumerable(self._source, self._key_selectors + [(key_selector, False)])
|
||||
|
||||
def then_by_descending(self, key_selector: Callable[[T], K]) -> "OrderedEnumerable[T]":
|
||||
return OrderedEnumerable(self._source, self._key_selectors + [(key_selector, True)])
|
||||
|
||||
|
||||
class _DescendingWrapper:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value > other.value
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
@@ -0,0 +1 @@
|
||||
from .sequence import Sequence
|
||||
|
||||
@@ -1,39 +1,13 @@
|
||||
from typing import Generic, Iterable, Iterator, Optional
|
||||
from typing import Generic, Iterable, Optional
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.query.immutable_set import ImmutableSet
|
||||
from cpl.query.enumerable import Enumerable
|
||||
|
||||
|
||||
class Set(Generic[T], Enumerable[T]):
|
||||
class Set(Generic[T], ImmutableSet[T]):
|
||||
def __init__(self, source: Optional[Iterable[T]] = None):
|
||||
if source is None:
|
||||
source = set()
|
||||
|
||||
self.__source = source
|
||||
Enumerable.__init__(self, [])
|
||||
|
||||
@property
|
||||
def _source(self) -> set[T]:
|
||||
return self.__source
|
||||
|
||||
@_source.setter
|
||||
def _source(self, value: set[T]) -> None:
|
||||
if not isinstance(value, set):
|
||||
value = set(value)
|
||||
|
||||
self.__source = value
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._source)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._source)
|
||||
|
||||
def __contains__(self, item: T) -> bool:
|
||||
return item in self._source
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Set({self._source!r})"
|
||||
ImmutableSet.__init__(self, source)
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
@@ -50,4 +24,5 @@ class Set(Generic[T], Enumerable[T]):
|
||||
|
||||
def to_enumerable(self) -> "Enumerable[T]":
|
||||
from cpl.query.enumerable import Enumerable
|
||||
return Enumerable(self._source)
|
||||
|
||||
return Enumerable(self._source)
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
|
||||
from cpl.application import ApplicationABC
|
||||
from cpl.core.configuration import ConfigurationABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from unittests_cli.cli_test_suite import CLITestSuite
|
||||
from unittests_core.core_test_suite import CoreTestSuite
|
||||
from unittests_query.query_test_suite import QueryTestSuite
|
||||
@@ -10,7 +10,7 @@ from unittests_translation.translation_test_suite import TranslationTestSuite
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, config: ConfigurationABC, services: ServiceProviderABC):
|
||||
def __init__(self, config: ConfigurationABC, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, config, services)
|
||||
|
||||
def configure(self): ...
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency import ServiceCollection, ServiceLifetimeEnum, ServiceProviderABC
|
||||
from cpl.dependency import ServiceCollection, ServiceLifetimeEnum, ServiceProvider
|
||||
|
||||
|
||||
class ServiceCollectionTestCase(unittest.TestCase):
|
||||
@@ -51,6 +51,6 @@ class ServiceCollectionTestCase(unittest.TestCase):
|
||||
service = self._sc._service_descriptors[0]
|
||||
self.assertIsNone(service.implementation)
|
||||
sp = self._sc.build()
|
||||
self.assertTrue(isinstance(sp, ServiceProviderABC))
|
||||
self.assertTrue(isinstance(sp, ServiceProvider))
|
||||
self.assertTrue(isinstance(sp.get_service(Mock), Mock))
|
||||
self.assertIsNotNone(service.implementation)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency import ServiceCollection, ServiceProviderABC
|
||||
from cpl.dependency import ServiceCollection, ServiceProvider
|
||||
|
||||
|
||||
class ServiceCount:
|
||||
@@ -10,21 +10,21 @@ class ServiceCount:
|
||||
|
||||
|
||||
class TestService:
|
||||
def __init__(self, sp: ServiceProviderABC, count: ServiceCount):
|
||||
def __init__(self, sp: ServiceProvider, count: ServiceCount):
|
||||
count.count += 1
|
||||
self.sp = sp
|
||||
self.id = count.count
|
||||
|
||||
|
||||
class DifferentService:
|
||||
def __init__(self, sp: ServiceProviderABC, count: ServiceCount):
|
||||
def __init__(self, sp: ServiceProvider, count: ServiceCount):
|
||||
count.count += 1
|
||||
self.sp = sp
|
||||
self.id = count.count
|
||||
|
||||
|
||||
class MoreDifferentService:
|
||||
def __init__(self, sp: ServiceProviderABC, count: ServiceCount):
|
||||
def __init__(self, sp: ServiceProvider, count: ServiceCount):
|
||||
count.count += 1
|
||||
self.sp = sp
|
||||
self.id = count.count
|
||||
@@ -72,7 +72,7 @@ class ServiceProviderTestCase(unittest.TestCase):
|
||||
singleton = self._services.get_service(TestService)
|
||||
transient = self._services.get_service(DifferentService)
|
||||
with self._services.create_scope() as scope:
|
||||
sp: ServiceProviderABC = scope.service_provider
|
||||
sp: ServiceProvider = scope.service_provider
|
||||
self.assertNotEqual(sp, self._services)
|
||||
y = sp.get_service(DifferentService)
|
||||
self.assertIsNotNone(y)
|
||||
|
||||
Reference in New Issue
Block a user