Dao complex filtering #181
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.service.schema import Schema
|
||||
from cpl.graphql.service.service import GraphQLService
|
||||
@@ -10,6 +15,7 @@ class GraphQLModule(Module):
|
||||
dependencies = [ApiModule]
|
||||
singleton = [Schema, RootQuery]
|
||||
scoped = [GraphQLService]
|
||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||
|
||||
@staticmethod
|
||||
def configure(services: ServiceProvider) -> None:
|
||||
|
||||
10
src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py
Normal file
10
src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class BoolFilter(Input[bool]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", bool, optional=True)
|
||||
self.field("notEqual", bool, optional=True)
|
||||
self.field("isNull", bool, optional=True)
|
||||
self.field("isNotNull", bool, optional=True)
|
||||
18
src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py
Normal file
18
src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class DateFilter(Input[datetime]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", datetime, optional=True)
|
||||
self.field("notEqual", datetime, optional=True)
|
||||
self.field("greater", datetime, optional=True)
|
||||
self.field("greaterOrEqual", datetime, optional=True)
|
||||
self.field("less", datetime, optional=True)
|
||||
self.field("lessOrEqual", datetime, optional=True)
|
||||
self.field("isNull", datetime, optional=True)
|
||||
self.field("isNotNull", datetime, optional=True)
|
||||
self.field("in", list[datetime], optional=True)
|
||||
self.field("notIn", list[datetime], optional=True)
|
||||
@@ -1,7 +1,23 @@
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class Filter(Input[T]):
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
|
||||
def string_field(self, name: str):
|
||||
self.field(name, StringFilter())
|
||||
|
||||
def int_field(self, name: str):
|
||||
self.field(name, IntFilter())
|
||||
|
||||
def bool_field(self, name: str):
|
||||
self.field(name, BoolFilter())
|
||||
|
||||
def date_field(self, name: str):
|
||||
self.field(name, DateFilter())
|
||||
|
||||
16
src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py
Normal file
16
src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class IntFilter(Input[int]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", int, optional=True)
|
||||
self.field("notEqual", int, optional=True)
|
||||
self.field("greater", int, optional=True)
|
||||
self.field("greaterOrEqual", int, optional=True)
|
||||
self.field("less", int, optional=True)
|
||||
self.field("lessOrEqual", int, optional=True)
|
||||
self.field("isNull", int, optional=True)
|
||||
self.field("isNotNull", int, optional=True)
|
||||
self.field("in", list[int], optional=True)
|
||||
self.field("notIn", list[int], optional=True)
|
||||
16
src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py
Normal file
16
src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class StringFilter(Input[str]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", str, optional=True)
|
||||
self.field("notEqual", str, optional=True)
|
||||
self.field("contains", str, optional=True)
|
||||
self.field("notContains", str, optional=True)
|
||||
self.field("startsWith", str, optional=True)
|
||||
self.field("endsWith", str, optional=True)
|
||||
self.field("isNull", str, optional=True)
|
||||
self.field("isNotNull", str, optional=True)
|
||||
self.field("in", list[str], optional=True)
|
||||
self.field("notIn", list[str], optional=True)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Generic, Dict, Type, Any, Optional
|
||||
from typing import Generic, Dict, Type, Optional, Self, Union
|
||||
|
||||
import strawberry
|
||||
|
||||
@@ -6,6 +6,7 @@ from cpl.core.typing import T
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||
|
||||
class Input(StrawberryProtocol, Generic[T]):
|
||||
def __init__(self):
|
||||
@@ -14,21 +15,40 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(self, name: str, typ: type, optional: bool = True):
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
|
||||
_registry: dict[type, Type] = {}
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if cls in self._registry:
|
||||
return self._registry[cls]
|
||||
|
||||
annotations = {}
|
||||
namespace = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
ann = f.type if not f.optional else Optional[f.type]
|
||||
annotations[name] = ann
|
||||
typ = f.type
|
||||
if isinstance(typ, type) and issubclass(typ, Input):
|
||||
typ = typ().to_strawberry()
|
||||
elif isinstance(typ, Input):
|
||||
typ = typ.to_strawberry()
|
||||
|
||||
if f.optional:
|
||||
namespace[name] = None
|
||||
elif f.default is not None:
|
||||
namespace[name] = f.default
|
||||
ann = typ if not f.optional else Optional[typ]
|
||||
|
||||
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
|
||||
annotations[py_name] = ann
|
||||
|
||||
field_args = {}
|
||||
if py_name != name:
|
||||
field_args["name"] = name
|
||||
|
||||
default = None if f.optional else f.default
|
||||
namespace[py_name] = strawberry.field(default=default, **field_args)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace))
|
||||
|
||||
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
||||
Input._registry[cls] = gql_type
|
||||
return gql_type
|
||||
|
||||
@@ -122,9 +122,29 @@ class Query(StrawberryProtocol):
|
||||
if not sort:
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
def input_to_dict(obj) -> dict | None:
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
for k, v in obj.__dict__.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
# verschachtelte Inputs rekursiv
|
||||
if hasattr(v, "__dict__"):
|
||||
result[k] = input_to_dict(v)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
async def _resolver(filter=None, sort=None, take=10, skip=0):
|
||||
filter_dict = input_to_dict(filter) if filter is not None else None
|
||||
sort_dict = None
|
||||
|
||||
if filter is not None:
|
||||
pass
|
||||
|
||||
if sort is not None:
|
||||
sort_dict = {}
|
||||
for k, v in sort.__dict__.items():
|
||||
@@ -137,8 +157,8 @@ class Query(StrawberryProtocol):
|
||||
|
||||
sort_dict[k] = str(v).lower()
|
||||
|
||||
total_count = await dao.count(filter)
|
||||
data = await dao.find_by(filter, sort_dict, take, skip)
|
||||
total_count = await dao.count(filter_dict)
|
||||
data = await dao.find_by(filter_dict, sort_dict, take, skip)
|
||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||
|
||||
Reference in New Issue
Block a user