Recursive filter #181
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 5s
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 5s
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from cpl.auth.schema import AuthUser
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
|
||||
|
||||
class AuthUserFilter(DbModelFilter[AuthUser]):
|
||||
def __init__(self, public: bool = False):
|
||||
DbModelFilter.__init__(self, public)
|
||||
|
||||
self.field("username", StringFilter).with_public(public)
|
||||
self.field("email", StringFilter).with_public(public)
|
||||
@@ -1,6 +1,7 @@
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter
|
||||
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
|
||||
|
||||
|
||||
class GraphQLAuthModule(Module):
|
||||
transient = [AuthUserGraphType]
|
||||
transient = [AuthUserGraphType, AuthUserFilter]
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from typing import Type, Generic
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class DbModelFilter(Filter[T], Generic[T]):
|
||||
@@ -15,6 +13,8 @@ class DbModelFilter(Filter[T], Generic[T]):
|
||||
|
||||
self.field("id", IntFilter).with_public(public)
|
||||
self.field("deleted", BoolFilter).with_public(public)
|
||||
# self.field("editor", AuthUserFilter)
|
||||
from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter
|
||||
|
||||
self.field("editor", lambda: AuthUserFilter).with_public(public)
|
||||
self.field("created", DateFilter).with_public(public)
|
||||
self.field("updated", DateFilter).with_public(public)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import types
|
||||
from typing import Generic, Dict, Type, Optional, Union, Any
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import AttributeName
|
||||
@@ -39,7 +41,7 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
||||
def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field:
|
||||
if isinstance(name, property):
|
||||
name = name.fget.__name__
|
||||
|
||||
@@ -62,6 +64,9 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
return self.field(name, list[t], optional)
|
||||
|
||||
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||
if not isinstance(t, type) and callable(t):
|
||||
return self.field(name, t, optional)
|
||||
|
||||
return self.field(name, t().to_strawberry(), optional)
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
@@ -69,20 +74,28 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations = {}
|
||||
namespace = {}
|
||||
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
|
||||
# register early to handle recursive types
|
||||
TypeCollector.set(cls, gql_cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
typ = f.type
|
||||
if isinstance(typ, type) and issubclass(typ, Input):
|
||||
typ = typ().to_strawberry()
|
||||
elif isinstance(typ, Input):
|
||||
typ = typ.to_strawberry()
|
||||
t = f.type
|
||||
|
||||
ann = typ if not f.optional else Optional[typ]
|
||||
if isinstance(t, types.FunctionType):
|
||||
_t = get_provider().get_service(t())
|
||||
if _t is None:
|
||||
raise ValueError(f"'{t()}' could not be resolved from the provider")
|
||||
t = _t.to_strawberry()
|
||||
elif isinstance(t, type) and issubclass(t, Input):
|
||||
t = t().to_strawberry()
|
||||
elif isinstance(t, Input):
|
||||
t = t.to_strawberry()
|
||||
|
||||
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
|
||||
annotations[py_name] = ann
|
||||
annotations[py_name] = t if not f.optional else Optional[t]
|
||||
|
||||
field_args = {}
|
||||
if py_name != name:
|
||||
@@ -93,6 +106,10 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
|
||||
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
||||
for k, v in namespace.items():
|
||||
setattr(gql_cls, k, v)
|
||||
|
||||
gql_cls.__annotations__ = annotations
|
||||
gql_type = strawberry.input(gql_cls)
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
Reference in New Issue
Block a user