Recursive filter #181
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 5s

This commit is contained in:
2025-09-28 22:06:50 +02:00
parent aeeb5d06c8
commit bd13bbca5d
4 changed files with 45 additions and 16 deletions

View File

@@ -0,0 +1,11 @@
from cpl.auth.schema import AuthUser
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class AuthUserFilter(DbModelFilter[AuthUser]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("username", StringFilter).with_public(public)
self.field("email", StringFilter).with_public(public)

View File

@@ -1,6 +1,7 @@
from cpl.dependency.module.module import Module from cpl.dependency.module.module import Module
from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
class GraphQLAuthModule(Module): class GraphQLAuthModule(Module):
transient = [AuthUserGraphType] transient = [AuthUserGraphType, AuthUserFilter]

View File

@@ -1,12 +1,10 @@
from typing import Type, Generic from typing import Generic
from cpl.core.typing import T from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.input import Input
class DbModelFilter(Filter[T], Generic[T]): class DbModelFilter(Filter[T], Generic[T]):
@@ -15,6 +13,8 @@ class DbModelFilter(Filter[T], Generic[T]):
self.field("id", IntFilter).with_public(public) self.field("id", IntFilter).with_public(public)
self.field("deleted", BoolFilter).with_public(public) self.field("deleted", BoolFilter).with_public(public)
# self.field("editor", AuthUserFilter) from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter
self.field("editor", lambda: AuthUserFilter).with_public(public)
self.field("created", DateFilter).with_public(public) self.field("created", DateFilter).with_public(public)
self.field("updated", DateFilter).with_public(public) self.field("updated", DateFilter).with_public(public)

View File

@@ -1,8 +1,10 @@
import types
from typing import Generic, Dict, Type, Optional, Union, Any from typing import Generic, Dict, Type, Optional, Union, Any
import strawberry import strawberry
from cpl.core.typing import T from cpl.core.typing import T
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.typing import AttributeName from cpl.graphql.typing import AttributeName
@@ -39,7 +41,7 @@ class Input(StrawberryProtocol, Generic[T]):
def get_fields(self) -> dict[str, Field]: def get_fields(self) -> dict[str, Field]:
return self._fields return self._fields
def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field: def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field:
if isinstance(name, property): if isinstance(name, property):
name = name.fget.__name__ name = name.fget.__name__
@@ -62,6 +64,9 @@ class Input(StrawberryProtocol, Generic[T]):
return self.field(name, list[t], optional) return self.field(name, list[t], optional)
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field: def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
if not isinstance(t, type) and callable(t):
return self.field(name, t, optional)
return self.field(name, t().to_strawberry(), optional) return self.field(name, t().to_strawberry(), optional)
def to_strawberry(self) -> Type: def to_strawberry(self) -> Type:
@@ -69,20 +74,28 @@ class Input(StrawberryProtocol, Generic[T]):
if TypeCollector.has(cls): if TypeCollector.has(cls):
return TypeCollector.get(cls) return TypeCollector.get(cls)
annotations = {} gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
namespace = {} # register early to handle recursive types
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items(): for name, f in self._fields.items():
typ = f.type t = f.type
if isinstance(typ, type) and issubclass(typ, Input):
typ = typ().to_strawberry()
elif isinstance(typ, Input):
typ = typ.to_strawberry()
ann = typ if not f.optional else Optional[typ] if isinstance(t, types.FunctionType):
_t = get_provider().get_service(t())
if _t is None:
raise ValueError(f"'{t()}' could not be resolved from the provider")
t = _t.to_strawberry()
elif isinstance(t, type) and issubclass(t, Input):
t = t().to_strawberry()
elif isinstance(t, Input):
t = t.to_strawberry()
py_name = name + "_" if name in _PYTHON_KEYWORDS else name py_name = name + "_" if name in _PYTHON_KEYWORDS else name
annotations[py_name] = ann annotations[py_name] = t if not f.optional else Optional[t]
field_args = {} field_args = {}
if py_name != name: if py_name != name:
@@ -93,6 +106,10 @@ class Input(StrawberryProtocol, Generic[T]):
namespace["__annotations__"] = annotations namespace["__annotations__"] = annotations
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) for k, v in namespace.items():
setattr(gql_cls, k, v)
gql_cls.__annotations__ = annotations
gql_type = strawberry.input(gql_cls)
TypeCollector.set(cls, gql_type) TypeCollector.set(cls, gql_type)
return gql_type return gql_type