Dao complex filtering #181
This commit is contained in:
@@ -7,9 +7,9 @@ from model.post import Post
|
|||||||
class PostFilter(Filter[Post]):
|
class PostFilter(Filter[Post]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Filter.__init__(self)
|
Filter.__init__(self)
|
||||||
self.field("id", int)
|
self.int_field("id")
|
||||||
self.field("title", str)
|
self.string_field("title")
|
||||||
self.field("content", str)
|
self.string_field("content")
|
||||||
|
|
||||||
class PostSort(Sort[Post]):
|
class PostSort(Sort[Post]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -18,7 +18,6 @@ class PostSort(Sort[Post]):
|
|||||||
self.field("title", SortOrder)
|
self.field("title", SortOrder)
|
||||||
self.field("content", SortOrder)
|
self.field("content", SortOrder)
|
||||||
|
|
||||||
|
|
||||||
class PostGraphType(GraphType[Post]):
|
class PostGraphType(GraphType[Post]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -35,4 +34,4 @@ class PostGraphType(GraphType[Post]):
|
|||||||
self.string_field(
|
self.string_field(
|
||||||
"content",
|
"content",
|
||||||
resolver=lambda root: root.content,
|
resolver=lambda root: root.content,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
from cpl.api.api_module import ApiModule
|
from cpl.api.api_module import ApiModule
|
||||||
from cpl.dependency.module.module import Module
|
from cpl.dependency.module.module import Module
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||||
|
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||||
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
|
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||||
|
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||||
from cpl.graphql.schema.root_query import RootQuery
|
from cpl.graphql.schema.root_query import RootQuery
|
||||||
from cpl.graphql.service.schema import Schema
|
from cpl.graphql.service.schema import Schema
|
||||||
from cpl.graphql.service.service import GraphQLService
|
from cpl.graphql.service.service import GraphQLService
|
||||||
@@ -10,6 +15,7 @@ class GraphQLModule(Module):
|
|||||||
dependencies = [ApiModule]
|
dependencies = [ApiModule]
|
||||||
singleton = [Schema, RootQuery]
|
singleton = [Schema, RootQuery]
|
||||||
scoped = [GraphQLService]
|
scoped = [GraphQLService]
|
||||||
|
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def configure(services: ServiceProvider) -> None:
|
def configure(services: ServiceProvider) -> None:
|
||||||
|
|||||||
10
src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py
Normal file
10
src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from cpl.graphql.schema.input import Input
|
||||||
|
|
||||||
|
|
||||||
|
class BoolFilter(Input[bool]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.field("equal", bool, optional=True)
|
||||||
|
self.field("notEqual", bool, optional=True)
|
||||||
|
self.field("isNull", bool, optional=True)
|
||||||
|
self.field("isNotNull", bool, optional=True)
|
||||||
18
src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py
Normal file
18
src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from cpl.graphql.schema.input import Input
|
||||||
|
|
||||||
|
|
||||||
|
class DateFilter(Input[datetime]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.field("equal", datetime, optional=True)
|
||||||
|
self.field("notEqual", datetime, optional=True)
|
||||||
|
self.field("greater", datetime, optional=True)
|
||||||
|
self.field("greaterOrEqual", datetime, optional=True)
|
||||||
|
self.field("less", datetime, optional=True)
|
||||||
|
self.field("lessOrEqual", datetime, optional=True)
|
||||||
|
self.field("isNull", datetime, optional=True)
|
||||||
|
self.field("isNotNull", datetime, optional=True)
|
||||||
|
self.field("in", list[datetime], optional=True)
|
||||||
|
self.field("notIn", list[datetime], optional=True)
|
||||||
@@ -1,7 +1,23 @@
|
|||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
|
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||||
|
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||||
|
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||||
|
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||||
from cpl.graphql.schema.input import Input
|
from cpl.graphql.schema.input import Input
|
||||||
|
|
||||||
|
|
||||||
class Filter(Input[T]):
|
class Filter(Input[T]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Input.__init__(self)
|
Input.__init__(self)
|
||||||
|
|
||||||
|
def string_field(self, name: str):
|
||||||
|
self.field(name, StringFilter())
|
||||||
|
|
||||||
|
def int_field(self, name: str):
|
||||||
|
self.field(name, IntFilter())
|
||||||
|
|
||||||
|
def bool_field(self, name: str):
|
||||||
|
self.field(name, BoolFilter())
|
||||||
|
|
||||||
|
def date_field(self, name: str):
|
||||||
|
self.field(name, DateFilter())
|
||||||
|
|||||||
16
src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py
Normal file
16
src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from cpl.graphql.schema.input import Input
|
||||||
|
|
||||||
|
|
||||||
|
class IntFilter(Input[int]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.field("equal", int, optional=True)
|
||||||
|
self.field("notEqual", int, optional=True)
|
||||||
|
self.field("greater", int, optional=True)
|
||||||
|
self.field("greaterOrEqual", int, optional=True)
|
||||||
|
self.field("less", int, optional=True)
|
||||||
|
self.field("lessOrEqual", int, optional=True)
|
||||||
|
self.field("isNull", int, optional=True)
|
||||||
|
self.field("isNotNull", int, optional=True)
|
||||||
|
self.field("in", list[int], optional=True)
|
||||||
|
self.field("notIn", list[int], optional=True)
|
||||||
16
src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py
Normal file
16
src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from cpl.graphql.schema.input import Input
|
||||||
|
|
||||||
|
|
||||||
|
class StringFilter(Input[str]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.field("equal", str, optional=True)
|
||||||
|
self.field("notEqual", str, optional=True)
|
||||||
|
self.field("contains", str, optional=True)
|
||||||
|
self.field("notContains", str, optional=True)
|
||||||
|
self.field("startsWith", str, optional=True)
|
||||||
|
self.field("endsWith", str, optional=True)
|
||||||
|
self.field("isNull", str, optional=True)
|
||||||
|
self.field("isNotNull", str, optional=True)
|
||||||
|
self.field("in", list[str], optional=True)
|
||||||
|
self.field("notIn", list[str], optional=True)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Generic, Dict, Type, Any, Optional
|
from typing import Generic, Dict, Type, Optional, Self, Union
|
||||||
|
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
@@ -6,6 +6,7 @@ from cpl.core.typing import T
|
|||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
|
|
||||||
|
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||||
|
|
||||||
class Input(StrawberryProtocol, Generic[T]):
|
class Input(StrawberryProtocol, Generic[T]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -14,21 +15,40 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
def get_fields(self) -> dict[str, Field]:
|
def get_fields(self) -> dict[str, Field]:
|
||||||
return self._fields
|
return self._fields
|
||||||
|
|
||||||
def field(self, name: str, typ: type, optional: bool = True):
|
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
||||||
self._fields[name] = Field(name, typ, optional=optional)
|
self._fields[name] = Field(name, typ, optional=optional)
|
||||||
|
|
||||||
|
_registry: dict[type, Type] = {}
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
def to_strawberry(self) -> Type:
|
||||||
|
cls = self.__class__
|
||||||
|
if cls in self._registry:
|
||||||
|
return self._registry[cls]
|
||||||
|
|
||||||
annotations = {}
|
annotations = {}
|
||||||
namespace = {}
|
namespace = {}
|
||||||
|
|
||||||
for name, f in self._fields.items():
|
for name, f in self._fields.items():
|
||||||
ann = f.type if not f.optional else Optional[f.type]
|
typ = f.type
|
||||||
annotations[name] = ann
|
if isinstance(typ, type) and issubclass(typ, Input):
|
||||||
|
typ = typ().to_strawberry()
|
||||||
|
elif isinstance(typ, Input):
|
||||||
|
typ = typ.to_strawberry()
|
||||||
|
|
||||||
if f.optional:
|
ann = typ if not f.optional else Optional[typ]
|
||||||
namespace[name] = None
|
|
||||||
elif f.default is not None:
|
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
|
||||||
namespace[name] = f.default
|
annotations[py_name] = ann
|
||||||
|
|
||||||
|
field_args = {}
|
||||||
|
if py_name != name:
|
||||||
|
field_args["name"] = name
|
||||||
|
|
||||||
|
default = None if f.optional else f.default
|
||||||
|
namespace[py_name] = strawberry.field(default=default, **field_args)
|
||||||
|
|
||||||
namespace["__annotations__"] = annotations
|
namespace["__annotations__"] = annotations
|
||||||
return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace))
|
|
||||||
|
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
||||||
|
Input._registry[cls] = gql_type
|
||||||
|
return gql_type
|
||||||
|
|||||||
@@ -122,9 +122,29 @@ class Query(StrawberryProtocol):
|
|||||||
if not sort:
|
if not sort:
|
||||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||||
|
|
||||||
|
def input_to_dict(obj) -> dict | None:
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for k, v in obj.__dict__.items():
|
||||||
|
if v is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# verschachtelte Inputs rekursiv
|
||||||
|
if hasattr(v, "__dict__"):
|
||||||
|
result[k] = input_to_dict(v)
|
||||||
|
else:
|
||||||
|
result[k] = v
|
||||||
|
return result
|
||||||
|
|
||||||
async def _resolver(filter=None, sort=None, take=10, skip=0):
|
async def _resolver(filter=None, sort=None, take=10, skip=0):
|
||||||
|
filter_dict = input_to_dict(filter) if filter is not None else None
|
||||||
sort_dict = None
|
sort_dict = None
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
pass
|
||||||
|
|
||||||
if sort is not None:
|
if sort is not None:
|
||||||
sort_dict = {}
|
sort_dict = {}
|
||||||
for k, v in sort.__dict__.items():
|
for k, v in sort.__dict__.items():
|
||||||
@@ -137,8 +157,8 @@ class Query(StrawberryProtocol):
|
|||||||
|
|
||||||
sort_dict[k] = str(v).lower()
|
sort_dict[k] = str(v).lower()
|
||||||
|
|
||||||
total_count = await dao.count(filter)
|
total_count = await dao.count(filter_dict)
|
||||||
data = await dao.find_by(filter, sort_dict, take, skip)
|
data = await dao.find_by(filter_dict, sort_dict, take, skip)
|
||||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||||
|
|
||||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||||
|
|||||||
Reference in New Issue
Block a user