From df69f1c7256572e1cdb62227ccd0d0157b90c431 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 22:06:50 +0200 Subject: [PATCH] Recursive filter #181 --- .../auth/administration/auth_user_filter.py | 11 ++++++ .../cpl/graphql/auth/graphql_auth_module.py | 3 +- .../graphql/schema/filter/db_model_filter.py | 8 ++-- src/cpl-graphql/cpl/graphql/schema/input.py | 39 +++++++++++++------ 4 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py new file mode 100644 index 00000000..19264a46 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py @@ -0,0 +1,11 @@ +from cpl.auth.schema import AuthUser +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class AuthUserFilter(DbModelFilter[AuthUser]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("username", StringFilter).with_public(public) + self.field("email", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index dc53f754..a0724910 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -1,6 +1,7 @@ from cpl.dependency.module.module import Module +from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType class GraphQLAuthModule(Module): - transient = [AuthUserGraphType] + transient = [AuthUserGraphType, AuthUserFilter] diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py index 860712fe..aa4fb4d8 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -1,12 +1,10 @@ -from typing import Type, Generic +from typing import Generic from cpl.core.typing import T from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.int_filter import IntFilter -from cpl.graphql.schema.filter.string_filter import StringFilter -from cpl.graphql.schema.input import Input class DbModelFilter(Filter[T], Generic[T]): @@ -15,6 +13,8 @@ class DbModelFilter(Filter[T], Generic[T]): self.field("id", IntFilter).with_public(public) self.field("deleted", BoolFilter).with_public(public) - # self.field("editor", AuthUserFilter) + from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter + + self.field("editor", lambda: AuthUserFilter).with_public(public) self.field("created", DateFilter).with_public(public) self.field("updated", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index bcba7ae0..ce7817ab 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,8 +1,10 @@ +import types from typing import Generic, Dict, Type, Optional, Union, Any import strawberry from cpl.core.typing import T +from cpl.dependency import get_provider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field from cpl.graphql.typing import AttributeName @@ -39,7 +41,7 @@ class Input(StrawberryProtocol, Generic[T]): def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field: + def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field: if isinstance(name, property): name = name.fget.__name__ @@ -62,6 +64,9 @@ class Input(StrawberryProtocol, Generic[T]): return self.field(name, list[t], optional) def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field: + if not isinstance(t, type) and callable(t): + return self.field(name, t, optional) + return self.field(name, t().to_strawberry(), optional) def to_strawberry(self) -> Type: @@ -69,20 +74,28 @@ class Input(StrawberryProtocol, Generic[T]): if TypeCollector.has(cls): return TypeCollector.get(cls) - annotations = {} - namespace = {} + gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {}) + # register early to handle recursive types + TypeCollector.set(cls, gql_cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} for name, f in self._fields.items(): - typ = f.type - if isinstance(typ, type) and issubclass(typ, Input): - typ = typ().to_strawberry() - elif isinstance(typ, Input): - typ = typ.to_strawberry() + t = f.type - ann = typ if not f.optional else Optional[typ] + if isinstance(t, types.FunctionType): + _t = get_provider().get_service(t()) + if _t is None: + raise ValueError(f"'{t()}' could not be resolved from the provider") + t = _t.to_strawberry() + elif isinstance(t, type) and issubclass(t, Input): + t = t().to_strawberry() + elif isinstance(t, Input): + t = t.to_strawberry() py_name = name + "_" if name in _PYTHON_KEYWORDS else name - annotations[py_name] = ann + annotations[py_name] = t if not f.optional else Optional[t] field_args = {} if py_name != name: @@ -93,6 +106,10 @@ class Input(StrawberryProtocol, Generic[T]): namespace["__annotations__"] = annotations - gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) + for k, v in namespace.items(): + setattr(gql_cls, k, v) + + gql_cls.__annotations__ = annotations + gql_type = strawberry.input(gql_cls) TypeCollector.set(cls, gql_type) return gql_type