Added cache

This commit is contained in:
2025-09-24 12:05:04 +02:00
parent c71a3df62c
commit a1cfe76047
5 changed files with 131 additions and 3 deletions

View File

@@ -0,0 +1,100 @@
import threading
import time
from typing import Generic
from cpl.core.typing import T
class Cache(Generic[T]):
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
self._store = {}
self._default_ttl = default_ttl
self._lock = threading.Lock()
self._cleanup_interval = cleanup_interval
self._stop_event = threading.Event()
self._type = t
# Start background cleanup thread
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
self._thread.start()
def set(self, key: str, value: T, ttl: int = None) -> None:
"""Store a value in the cache with optional TTL override."""
expire_at = None
ttl = ttl if ttl is not None else self._default_ttl
if ttl is not None:
expire_at = time.time() + ttl
with self._lock:
self._store[key] = (value, expire_at)
def get(self, key: str) -> T | None:
"""Retrieve a value from the cache if not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return None
value, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return None
del self._store[key]
return None
return value
def get_all(self) -> list[T]:
"""Retrieve all non-expired values from the cache."""
now = time.time()
with self._lock:
valid_items = []
expired_keys = []
for k, (v, exp) in self._store.items():
if exp and exp < now:
expired_keys.append(k)
else:
valid_items.append(v)
for k in expired_keys:
del self._store[k]
return valid_items
def has(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
with self._lock:
item = self._store.get(key)
if not item:
return False
_, expire_at = item
if expire_at and expire_at < time.time():
# Expired -> remove and return False
del self._store[key]
return False
return True
def delete(self, key: str) -> None:
"""Remove an item from the cache."""
with self._lock:
self._store.pop(key, None)
def clear(self) -> None:
"""Clear the entire cache."""
with self._lock:
self._store.clear()
def _auto_cleanup(self):
"""Background thread to clean expired items."""
while not self._stop_event.is_set():
self.cleanup()
self._stop_event.wait(self._cleanup_interval)
def cleanup(self) -> None:
"""Remove expired items immediately."""
now = time.time()
with self._lock:
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
for k in expired_keys:
del self._store[k]
def stop(self):
"""Stop the background cleanup thread."""
self._stop_event.set()
self._thread.join()

View File

@@ -2,6 +2,7 @@ from typing import Union, Type, Callable, Self
from cpl.core.log.logger_abc import LoggerABC
from cpl.core.typing import T, Service
from cpl.core.utils.cache import Cache
from cpl.dependency.service_descriptor import ServiceDescriptor
from cpl.dependency.service_lifetime_enum import ServiceLifetimeEnum
from cpl.dependency.service_provider import ServiceProvider
@@ -96,3 +97,7 @@ class ServiceCollection:
for wrapper in WrappedLogger.__subclasses__():
self.add_transient(wrapper)
return self
def add_cache(self, t: Type[T]):
self._service_descriptors.append(ServiceDescriptor(Cache(t=t), ServiceLifetimeEnum.singleton, Cache[t]))
return self

View File

@@ -37,8 +37,23 @@ class ServiceProvider(ServiceProviderABC):
self._scope: Optional[ScopeABC] = None
def _find_service(self, service_type: type) -> Optional[ServiceDescriptor]:
origin_type = typing.get_origin(service_type) or service_type
type_args = list(typing.get_args(service_type))
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.base_type, service_type):
descriptor_base_type = typing.get_origin(descriptor.base_type) or descriptor.base_type
descriptor_type_args = list(typing.get_args(descriptor.base_type))
if descriptor_base_type == origin_type and len(descriptor_type_args) == 0 and len(type_args) == 0:
return descriptor
if descriptor_base_type != origin_type or len(descriptor_type_args) != len(type_args):
continue
if descriptor_base_type == origin_type and type_args != descriptor_type_args:
continue
if descriptor.service_type == origin_type or issubclass(descriptor.base_type, origin_type):
return descriptor
return None
@@ -158,7 +173,6 @@ class ServiceProvider(ServiceProviderABC):
return implementation
def get_service_type(self, service_type: Type[T]) -> Optional[Type[T]]:
for descriptor in self._service_descriptors:
if descriptor.service_type == service_type or issubclass(descriptor.service_type, service_type):

View File

@@ -86,7 +86,7 @@ class ServiceProviderABC(ABC):
"""
@abstractmethod
def get_service_type(self,instance_type: Type[T]) -> Optional[Type[T]]:
def get_service_type(self, instance_type: Type[T]) -> Optional[Type[T]]:
r"""Returns the registered service type for loggers
Parameter