Recursive complex filtering #181
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
import sqlparse
|
||||
import asyncio
|
||||
|
||||
from mysql.connector import errors, PoolError
|
||||
from mysql.connector.aio import MySQLConnectionPool
|
||||
|
||||
from cpl.core.environment import Environment
|
||||
@@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class MySQLPool:
|
||||
|
||||
def __init__(self, database_settings: DatabaseSettings):
|
||||
self._dbconfig = {
|
||||
"host": database_settings.host,
|
||||
@@ -25,59 +26,87 @@ class MySQLPool:
|
||||
"ssl_disabled": database_settings.ssl_disabled,
|
||||
}
|
||||
self._pool: Optional[MySQLConnectionPool] = None
|
||||
self._pool_lock = asyncio.Lock()
|
||||
|
||||
async def _get_pool(self):
|
||||
async def _get_pool(self) -> MySQLConnectionPool:
|
||||
if self._pool is None:
|
||||
try:
|
||||
self._pool = MySQLConnectionPool(
|
||||
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
||||
)
|
||||
await self._pool.initialize_pool()
|
||||
async with self._pool_lock:
|
||||
if self._pool is None:
|
||||
try:
|
||||
self._pool = MySQLConnectionPool(
|
||||
pool_name="cplpool",
|
||||
pool_size=Environment.get("DB_POOL_SIZE", int, 20),
|
||||
**self._dbconfig,
|
||||
)
|
||||
await self._pool.initialize_pool()
|
||||
|
||||
con = await self._pool.get_connection()
|
||||
async with await con.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
|
||||
await con.close()
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal(f"Error connecting to the database", e)
|
||||
# Testverbindung (Ping)
|
||||
con = await self._pool.get_connection()
|
||||
try:
|
||||
async with await con.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.fatal("Error connecting to the database", e)
|
||||
raise
|
||||
return self._pool
|
||||
|
||||
async def _get_connection(self, retries: int = 3, delay: float = 0.5):
|
||||
"""Stabiler Connection-Getter mit Retry und Ping"""
|
||||
pool = await self._get_pool()
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
con = await pool.get_connection()
|
||||
|
||||
# Verbindungs-Check (Ping)
|
||||
try:
|
||||
async with await con.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
await cursor.fetchall()
|
||||
except errors.OperationalError:
|
||||
await con.close()
|
||||
raise
|
||||
|
||||
return con
|
||||
|
||||
except PoolError:
|
||||
if attempt == retries - 1:
|
||||
raise
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
@staticmethod
|
||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||
result = []
|
||||
if multi:
|
||||
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
||||
for q in queries:
|
||||
if q.strip() == "":
|
||||
continue
|
||||
await cursor.execute(q, args)
|
||||
if cursor.description is not None:
|
||||
result = await cursor.fetchall()
|
||||
if q:
|
||||
await cursor.execute(q, args)
|
||||
if cursor.description is not None:
|
||||
result = await cursor.fetchall()
|
||||
else:
|
||||
await cursor.execute(query, args)
|
||||
if cursor.description is not None:
|
||||
result = await cursor.fetchall()
|
||||
|
||||
return result
|
||||
|
||||
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
||||
pool = await self._get_pool()
|
||||
con = await pool.get_connection()
|
||||
async def execute(self, query: str, args=None, multi=True) -> list[str]:
|
||||
con = await self._get_connection()
|
||||
try:
|
||||
async with await con.cursor() as cursor:
|
||||
result = await self._exec_sql(cursor, query, args, multi)
|
||||
res = await self._exec_sql(cursor, query, args, multi)
|
||||
await con.commit()
|
||||
return result
|
||||
return list(res)
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
||||
pool = await self._get_pool()
|
||||
con = await pool.get_connection()
|
||||
con = await self._get_connection()
|
||||
try:
|
||||
async with await con.cursor() as cursor:
|
||||
res = await self._exec_sql(cursor, query, args, multi)
|
||||
@@ -86,8 +115,7 @@ class MySQLPool:
|
||||
await con.close()
|
||||
|
||||
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
||||
pool = await self._get_pool()
|
||||
con = await pool.get_connection()
|
||||
con = await self._get_connection()
|
||||
try:
|
||||
async with await con.cursor(dictionary=True) as cursor:
|
||||
res = await self._exec_sql(cursor, query, args, multi)
|
||||
|
||||
@@ -27,7 +27,7 @@ class PostgresPool:
|
||||
self._pool: Optional[AsyncConnectionPool] = None
|
||||
|
||||
async def _get_pool(self):
|
||||
if self._pool is None:
|
||||
if self._pool is None or self._pool.closed:
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Type, Dict, List
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
|
||||
|
||||
@@ -14,7 +15,12 @@ class CollectionGraphTypeFactory:
|
||||
if node_type in cls._cache:
|
||||
return cls._cache[node_type]
|
||||
|
||||
gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||
node_t = get_provider().get_service(node_type)
|
||||
if not node_t:
|
||||
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
|
||||
|
||||
|
||||
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||
|
||||
gql_type = strawberry.type(
|
||||
type(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
@@ -10,6 +12,9 @@ class Filter(Input[T]):
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
|
||||
def filter_field(self, name: str, filter_type: Type["Filter"]):
|
||||
self.field(name, filter_type())
|
||||
|
||||
def string_field(self, name: str):
|
||||
self.field(name, StringFilter())
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import strawberry
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||
|
||||
@@ -18,12 +19,10 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
|
||||
_registry: dict[type, Type] = {}
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if cls in self._registry:
|
||||
return self._registry[cls]
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations = {}
|
||||
namespace = {}
|
||||
@@ -50,5 +49,5 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
namespace["__annotations__"] = annotations
|
||||
|
||||
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
||||
Input._registry[cls] = gql_type
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
@@ -12,6 +12,7 @@ from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class Query(StrawberryProtocol):
|
||||
@@ -54,6 +55,9 @@ class Query(StrawberryProtocol):
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
||||
sub = self._provider.get_service(subquery_cls)
|
||||
if not sub:
|
||||
@@ -221,6 +225,10 @@ class Query(StrawberryProtocol):
|
||||
) from e
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
@@ -229,4 +237,6 @@ class Query(StrawberryProtocol):
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
17
src/cpl-graphql/cpl/graphql/utils/type_collector.py
Normal file
17
src/cpl-graphql/cpl/graphql/utils/type_collector.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Type
|
||||
|
||||
|
||||
class TypeCollector:
|
||||
_registry: dict[type, Type] = {}
|
||||
|
||||
@classmethod
|
||||
def has(cls, base: type) -> bool:
|
||||
return base in cls._registry
|
||||
|
||||
@classmethod
|
||||
def get(cls, base: type) -> Type:
|
||||
return cls._registry[base]
|
||||
|
||||
@classmethod
|
||||
def set(cls, base: type, gql_type: Type):
|
||||
cls._registry[base] = gql_type
|
||||
Reference in New Issue
Block a user