diff --git a/src/cpl-core/cpl/core/utils/cache.py b/src/cpl-core/cpl/core/utils/cache.py new file mode 100644 index 00000000..81d945bf --- /dev/null +++ b/src/cpl-core/cpl/core/utils/cache.py @@ -0,0 +1,100 @@ +import threading +import time +from typing import Generic + +from cpl.core.typing import T + + +class Cache(Generic[T]): + def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None): + self._store = {} + self._default_ttl = default_ttl + self._lock = threading.Lock() + self._cleanup_interval = cleanup_interval + self._stop_event = threading.Event() + + self._type = t + + # Start background cleanup thread + self._thread = threading.Thread(target=self._auto_cleanup, daemon=True) + self._thread.start() + + def set(self, key: str, value: T, ttl: int = None) -> None: + """Store a value in the cache with optional TTL override.""" + expire_at = None + ttl = ttl if ttl is not None else self._default_ttl + if ttl is not None: + expire_at = time.time() + ttl + + with self._lock: + self._store[key] = (value, expire_at) + + def get(self, key: str) -> T | None: + """Retrieve a value from the cache if not expired.""" + with self._lock: + item = self._store.get(key) + if not item: + return None + value, expire_at = item + if expire_at and expire_at < time.time(): + # Expired -> remove and return None + del self._store[key] + return None + return value + + def get_all(self) -> list[T]: + """Retrieve all non-expired values from the cache.""" + now = time.time() + with self._lock: + valid_items = [] + expired_keys = [] + for k, (v, exp) in self._store.items(): + if exp and exp < now: + expired_keys.append(k) + else: + valid_items.append(v) + for k in expired_keys: + del self._store[k] + return valid_items + + def has(self, key: str) -> bool: + """Check if a key exists and is not expired.""" + with self._lock: + item = self._store.get(key) + if not item: + return False + _, expire_at = item + if expire_at and expire_at < time.time(): + # Expired -> remove and return False + del self._store[key] + return False + return True + + def delete(self, key: str) -> None: + """Remove an item from the cache.""" + with self._lock: + self._store.pop(key, None) + + def clear(self) -> None: + """Clear the entire cache.""" + with self._lock: + self._store.clear() + + def _auto_cleanup(self): + """Background thread to clean expired items.""" + while not self._stop_event.is_set(): + self.cleanup() + self._stop_event.wait(self._cleanup_interval) + + def cleanup(self) -> None: + """Remove expired items immediately.""" + now = time.time() + with self._lock: + expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now] + for k in expired_keys: + del self._store[k] + + def stop(self): + """Stop the background cleanup thread.""" + self._stop_event.set() + self._thread.join() diff --git a/src/cpl-dependency/cpl/dependency/service_collection.py b/src/cpl-dependency/cpl/dependency/service_collection.py index e2088050..52fa500a 100644 --- a/src/cpl-dependency/cpl/dependency/service_collection.py +++ b/src/cpl-dependency/cpl/dependency/service_collection.py @@ -2,6 +2,7 @@ from typing import Union, Type, Callable, Self from cpl.core.log.logger_abc import LoggerABC from cpl.core.typing import T, Service +from cpl.core.utils.cache import Cache from cpl.dependency.service_descriptor import ServiceDescriptor from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum from cpl.dependency.service_provider import ServiceProvider @@ -96,3 +97,7 @@ class ServiceCollection: for wrapper in WrappedLogger.__subclasses__(): self.add_transient(wrapper) return self + + def add_cache(self, t: Type[T]): + self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t])) + return self diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index 3dcac39d..0eedc46f 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -37,8 +37,23 @@ class ServiceProvider(ServiceProviderABC): self._scope: Optional[ScopeABC] = None def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]: + origin_type = typing.get_origin(service_type) or service_type + type_args = list(typing.get_args(service_type)) + for descriptor in self._service_descriptors: - if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type): + descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type + descriptor_type_args = list(typing.get_args(descriptor.base_type)) + + if descriptor_base_type == origin_type and len(descriptor_type_args) == 0 and len(type_args) == 0: + return descriptor + + if descriptor_base_type != origin_type or len(descriptor_type_args) != len(type_args): + continue + + if descriptor_base_type == origin_type and type_args != descriptor_type_args: + continue + + if descriptor.service_type == origin_type or issubclass(descriptor.base_type, origin_type): return descriptor return None @@ -158,7 +173,6 @@ class ServiceProvider(ServiceProviderABC): return implementation - def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]: for descriptor in self._service_descriptors: if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type): diff --git a/src/cpl-dependency/cpl/dependency/service_provider_abc.py b/src/cpl-dependency/cpl/dependency/service_provider_abc.py index 93ba3c8a..53e87ba2 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider_abc.py +++ b/src/cpl-dependency/cpl/dependency/service_provider_abc.py @@ -86,7 +86,7 @@ class ServiceProviderABC(ABC): """ @abstractmethod - def get_service_type(self,instance_type: Type[T]) -> Optional[Type[T]]: + def get_service_type(self, instance_type: Type[T]) -> Optional[Type[T]]: r"""Returns the registered service type for loggers Parameter diff --git a/tests/custom/api/src/main.py b/tests/custom/api/src/main.py index 4432a97e..3a5ee2f9 100644 --- a/tests/custom/api/src/main.py +++ b/tests/custom/api/src/main.py @@ -4,8 +4,10 @@ from cpl import api from cpl.api.application.web_app import WebApp from cpl.application import ApplicationBuilder from cpl.auth.permission.permissions import Permissions +from cpl.auth.schema import AuthUser, Role from cpl.core.configuration import Configuration from cpl.core.environment import Environment +from cpl.core.utils.cache import Cache from service import PingService @@ -21,6 +23,9 @@ def main(): builder.services.add_transient(PingService) builder.services.add_module(api) + builder.services.add_cache(AuthUser) + builder.services.add_cache(Role) + app = builder.build() app.with_logging() app.with_database() @@ -31,6 +36,10 @@ def main(): app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", authentication=True, permissions=[Permissions.administrator]) app.with_routes_directory("routes") + provider = builder.service_provider + user_cache = provider.get_service(Cache[AuthUser]) + role_cache = provider.get_service(Cache[Role]) + app.run()