Dao complex filtering #181
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s

This commit is contained in:
2025-09-27 22:35:48 +02:00
parent af7945fe92
commit 2634ad3dca
9 changed files with 137 additions and 16 deletions

View File

@@ -7,9 +7,9 @@ from model.post import Post
class PostFilter(Filter[Post]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("title", str)
self.field("content", str)
self.int_field("id")
self.string_field("title")
self.string_field("content")
class PostSort(Sort[Post]):
def __init__(self):
@@ -18,7 +18,6 @@ class PostSort(Sort[Post]):
self.field("title", SortOrder)
self.field("content", SortOrder)
class PostGraphType(GraphType[Post]):
def __init__(self):

View File

@@ -1,6 +1,11 @@
from cpl.api.api_module import ApiModule
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema
from cpl.graphql.service.service import GraphQLService
@@ -10,6 +15,7 @@ class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [Schema, RootQuery]
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
@staticmethod
def configure(services: ServiceProvider) -> None:

View File

@@ -0,0 +1,10 @@
from cpl.graphql.schema.input import Input
class BoolFilter(Input[bool]):
def __init__(self):
super().__init__()
self.field("equal", bool, optional=True)
self.field("notEqual", bool, optional=True)
self.field("isNull", bool, optional=True)
self.field("isNotNull", bool, optional=True)

View File

@@ -0,0 +1,18 @@
from datetime import datetime
from cpl.graphql.schema.input import Input
class DateFilter(Input[datetime]):
def __init__(self):
super().__init__()
self.field("equal", datetime, optional=True)
self.field("notEqual", datetime, optional=True)
self.field("greater", datetime, optional=True)
self.field("greaterOrEqual", datetime, optional=True)
self.field("less", datetime, optional=True)
self.field("lessOrEqual", datetime, optional=True)
self.field("isNull", datetime, optional=True)
self.field("isNotNull", datetime, optional=True)
self.field("in", list[datetime], optional=True)
self.field("notIn", list[datetime], optional=True)

View File

@@ -1,7 +1,23 @@
from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.input import Input
class Filter(Input[T]):
def __init__(self):
Input.__init__(self)
def string_field(self, name: str):
self.field(name, StringFilter())
def int_field(self, name: str):
self.field(name, IntFilter())
def bool_field(self, name: str):
self.field(name, BoolFilter())
def date_field(self, name: str):
self.field(name, DateFilter())

View File

@@ -0,0 +1,16 @@
from cpl.graphql.schema.input import Input
class IntFilter(Input[int]):
def __init__(self):
super().__init__()
self.field("equal", int, optional=True)
self.field("notEqual", int, optional=True)
self.field("greater", int, optional=True)
self.field("greaterOrEqual", int, optional=True)
self.field("less", int, optional=True)
self.field("lessOrEqual", int, optional=True)
self.field("isNull", int, optional=True)
self.field("isNotNull", int, optional=True)
self.field("in", list[int], optional=True)
self.field("notIn", list[int], optional=True)

View File

@@ -0,0 +1,16 @@
from cpl.graphql.schema.input import Input
class StringFilter(Input[str]):
def __init__(self):
super().__init__()
self.field("equal", str, optional=True)
self.field("notEqual", str, optional=True)
self.field("contains", str, optional=True)
self.field("notContains", str, optional=True)
self.field("startsWith", str, optional=True)
self.field("endsWith", str, optional=True)
self.field("isNull", str, optional=True)
self.field("isNotNull", str, optional=True)
self.field("in", list[str], optional=True)
self.field("notIn", list[str], optional=True)

View File

@@ -1,4 +1,4 @@
from typing import Generic, Dict, Type, Any, Optional
from typing import Generic, Dict, Type, Optional, Self, Union
import strawberry
@@ -6,6 +6,7 @@ from cpl.core.typing import T
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
class Input(StrawberryProtocol, Generic[T]):
def __init__(self):
@@ -14,21 +15,40 @@ class Input(StrawberryProtocol, Generic[T]):
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(self, name: str, typ: type, optional: bool = True):
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
self._fields[name] = Field(name, typ, optional=optional)
_registry: dict[type, Type] = {}
def to_strawberry(self) -> Type:
cls = self.__class__
if cls in self._registry:
return self._registry[cls]
annotations = {}
namespace = {}
for name, f in self._fields.items():
ann = f.type if not f.optional else Optional[f.type]
annotations[name] = ann
typ = f.type
if isinstance(typ, type) and issubclass(typ, Input):
typ = typ().to_strawberry()
elif isinstance(typ, Input):
typ = typ.to_strawberry()
if f.optional:
namespace[name] = None
elif f.default is not None:
namespace[name] = f.default
ann = typ if not f.optional else Optional[typ]
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
annotations[py_name] = ann
field_args = {}
if py_name != name:
field_args["name"] = name
default = None if f.optional else f.default
namespace[py_name] = strawberry.field(default=default, **field_args)
namespace["__annotations__"] = annotations
return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace))
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
Input._registry[cls] = gql_type
return gql_type

View File

@@ -122,9 +122,29 @@ class Query(StrawberryProtocol):
if not sort:
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
def input_to_dict(obj) -> dict | None:
if obj is None:
return None
result = {}
for k, v in obj.__dict__.items():
if v is None:
continue
# verschachtelte Inputs rekursiv
if hasattr(v, "__dict__"):
result[k] = input_to_dict(v)
else:
result[k] = v
return result
async def _resolver(filter=None, sort=None, take=10, skip=0):
filter_dict = input_to_dict(filter) if filter is not None else None
sort_dict = None
if filter is not None:
pass
if sort is not None:
sort_dict = {}
for k, v in sort.__dict__.items():
@@ -137,8 +157,8 @@ class Query(StrawberryProtocol):
sort_dict[k] = str(v).lower()
total_count = await dao.count(filter)
data = await dao.find_by(filter, sort_dict, take, skip)
total_count = await dao.count(filter_dict)
data = await dao.find_by(filter_dict, sort_dict, take, skip)
return Collection(nodes=data, total_count=total_count, count=len(data))
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)