From a12a4082dbdb6fd75404515bcd755be8520b906d Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 22:35:48 +0200 Subject: [PATCH] Dao complex filtering #181 --- example/api/src/model/post_query.py | 9 ++--- src/cpl-graphql/cpl/graphql/graphql_module.py | 6 +++ .../cpl/graphql/schema/filter/bool_filter.py | 10 +++++ .../cpl/graphql/schema/filter/date_filter.py | 18 +++++++++ .../cpl/graphql/schema/filter/filter.py | 16 ++++++++ .../cpl/graphql/schema/filter/int_filter.py | 16 ++++++++ .../graphql/schema/filter/string_filter.py | 16 ++++++++ src/cpl-graphql/cpl/graphql/schema/input.py | 38 ++++++++++++++----- src/cpl-graphql/cpl/graphql/schema/query.py | 24 +++++++++++- 9 files changed, 137 insertions(+), 16 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 6e25dddc..2e5b2998 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -7,9 +7,9 @@ from model.post import Post class PostFilter(Filter[Post]): def __init__(self): Filter.__init__(self) - self.field("id", int) - self.field("title", str) - self.field("content", str) + self.int_field("id") + self.string_field("title") + self.string_field("content") class PostSort(Sort[Post]): def __init__(self): @@ -18,7 +18,6 @@ class PostSort(Sort[Post]): self.field("title", SortOrder) self.field("content", SortOrder) - class PostGraphType(GraphType[Post]): def __init__(self): @@ -35,4 +34,4 @@ class PostGraphType(GraphType[Post]): self.string_field( "content", resolver=lambda root: root.content, - ) \ No newline at end of file + ) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 70efa400..2d5d6b93 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,6 +1,11 @@ from cpl.api.api_module import ApiModule from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema from cpl.graphql.service.service import GraphQLService @@ -10,6 +15,7 @@ class GraphQLModule(Module): dependencies = [ApiModule] singleton = [Schema, RootQuery] scoped = [GraphQLService] + transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] @staticmethod def configure(services: ServiceProvider) -> None: diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py new file mode 100644 index 00000000..4be0db85 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py @@ -0,0 +1,10 @@ +from cpl.graphql.schema.input import Input + + +class BoolFilter(Input[bool]): + def __init__(self): + super().__init__() + self.field("equal", bool, optional=True) + self.field("notEqual", bool, optional=True) + self.field("isNull", bool, optional=True) + self.field("isNotNull", bool, optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py new file mode 100644 index 00000000..2dd1bcf8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from cpl.graphql.schema.input import Input + + +class DateFilter(Input[datetime]): + def __init__(self): + super().__init__() + self.field("equal", datetime, optional=True) + self.field("notEqual", datetime, optional=True) + self.field("greater", datetime, optional=True) + self.field("greaterOrEqual", datetime, optional=True) + self.field("less", datetime, optional=True) + self.field("lessOrEqual", datetime, optional=True) + self.field("isNull", datetime, optional=True) + self.field("isNotNull", datetime, optional=True) + self.field("in", list[datetime], optional=True) + self.field("notIn", list[datetime], optional=True) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index 2f76c4b4..d1d502e2 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -1,7 +1,23 @@ from cpl.core.typing import T +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.input import Input class Filter(Input[T]): def __init__(self): Input.__init__(self) + + def string_field(self, name: str): + self.field(name, StringFilter()) + + def int_field(self, name: str): + self.field(name, IntFilter()) + + def bool_field(self, name: str): + self.field(name, BoolFilter()) + + def date_field(self, name: str): + self.field(name, DateFilter()) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py new file mode 100644 index 00000000..be9eba74 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py @@ -0,0 +1,16 @@ +from cpl.graphql.schema.input import Input + + +class IntFilter(Input[int]): + def __init__(self): + super().__init__() + self.field("equal", int, optional=True) + self.field("notEqual", int, optional=True) + self.field("greater", int, optional=True) + self.field("greaterOrEqual", int, optional=True) + self.field("less", int, optional=True) + self.field("lessOrEqual", int, optional=True) + self.field("isNull", int, optional=True) + self.field("isNotNull", int, optional=True) + self.field("in", list[int], optional=True) + self.field("notIn", list[int], optional=True) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py new file mode 100644 index 00000000..7c060abc --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py @@ -0,0 +1,16 @@ +from cpl.graphql.schema.input import Input + + +class StringFilter(Input[str]): + def __init__(self): + super().__init__() + self.field("equal", str, optional=True) + self.field("notEqual", str, optional=True) + self.field("contains", str, optional=True) + self.field("notContains", str, optional=True) + self.field("startsWith", str, optional=True) + self.field("endsWith", str, optional=True) + self.field("isNull", str, optional=True) + self.field("isNotNull", str, optional=True) + self.field("in", list[str], optional=True) + self.field("notIn", list[str], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 4c9afc86..82ff31de 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,4 +1,4 @@ -from typing import Generic, Dict, Type, Any, Optional +from typing import Generic, Dict, Type, Optional, Self, Union import strawberry @@ -6,6 +6,7 @@ from cpl.core.typing import T from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field +_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} class Input(StrawberryProtocol, Generic[T]): def __init__(self): @@ -14,21 +15,40 @@ class Input(StrawberryProtocol, Generic[T]): def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: str, typ: type, optional: bool = True): + def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): self._fields[name] = Field(name, typ, optional=optional) + _registry: dict[type, Type] = {} + def to_strawberry(self) -> Type: + cls = self.__class__ + if cls in self._registry: + return self._registry[cls] + annotations = {} namespace = {} for name, f in self._fields.items(): - ann = f.type if not f.optional else Optional[f.type] - annotations[name] = ann + typ = f.type + if isinstance(typ, type) and issubclass(typ, Input): + typ = typ().to_strawberry() + elif isinstance(typ, Input): + typ = typ.to_strawberry() - if f.optional: - namespace[name] = None - elif f.default is not None: - namespace[name] = f.default + ann = typ if not f.optional else Optional[typ] + + py_name = name + "_" if name in _PYTHON_KEYWORDS else name + annotations[py_name] = ann + + field_args = {} + if py_name != name: + field_args["name"] = name + + default = None if f.optional else f.default + namespace[py_name] = strawberry.field(default=default, **field_args) namespace["__annotations__"] = annotations - return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace)) + + gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) + Input._registry[cls] = gql_type + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 5539c4c6..736de81f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -122,9 +122,29 @@ class Query(StrawberryProtocol): if not sort: raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + def input_to_dict(obj) -> dict | None: + if obj is None: + return None + + result = {} + for k, v in obj.__dict__.items(): + if v is None: + continue + + # verschachtelte Inputs rekursiv + if hasattr(v, "__dict__"): + result[k] = input_to_dict(v) + else: + result[k] = v + return result + async def _resolver(filter=None, sort=None, take=10, skip=0): + filter_dict = input_to_dict(filter) if filter is not None else None sort_dict = None + if filter is not None: + pass + if sort is not None: sort_dict = {} for k, v in sort.__dict__.items(): @@ -137,8 +157,8 @@ class Query(StrawberryProtocol): sort_dict[k] = str(v).lower() - total_count = await dao.count(filter) - data = await dao.find_by(filter, sort_dict, take, skip) + total_count = await dao.count(filter_dict) + data = await dao.find_by(filter_dict, sort_dict, take, skip) return Collection(nodes=data, total_count=total_count, count=len(data)) f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)