Recursive filter #181
This commit is contained in:
@@ -0,0 +1,11 @@
|
|||||||
|
from cpl.auth.schema import AuthUser
|
||||||
|
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||||
|
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||||
|
|
||||||
|
|
||||||
|
class AuthUserFilter(DbModelFilter[AuthUser]):
|
||||||
|
def __init__(self, public: bool = False):
|
||||||
|
DbModelFilter.__init__(self, public)
|
||||||
|
|
||||||
|
self.field("username", StringFilter).with_public(public)
|
||||||
|
self.field("email", StringFilter).with_public(public)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from cpl.dependency.module.module import Module
|
from cpl.dependency.module.module import Module
|
||||||
|
from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter
|
||||||
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
|
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
|
||||||
|
|
||||||
|
|
||||||
class GraphQLAuthModule(Module):
|
class GraphQLAuthModule(Module):
|
||||||
transient = [AuthUserGraphType]
|
transient = [AuthUserGraphType, AuthUserFilter]
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
from typing import Type, Generic
|
from typing import Generic
|
||||||
|
|
||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||||
from cpl.graphql.schema.filter.filter import Filter
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
|
||||||
from cpl.graphql.schema.input import Input
|
|
||||||
|
|
||||||
|
|
||||||
class DbModelFilter(Filter[T], Generic[T]):
|
class DbModelFilter(Filter[T], Generic[T]):
|
||||||
@@ -15,6 +13,8 @@ class DbModelFilter(Filter[T], Generic[T]):
|
|||||||
|
|
||||||
self.field("id", IntFilter).with_public(public)
|
self.field("id", IntFilter).with_public(public)
|
||||||
self.field("deleted", BoolFilter).with_public(public)
|
self.field("deleted", BoolFilter).with_public(public)
|
||||||
# self.field("editor", AuthUserFilter)
|
from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter
|
||||||
|
|
||||||
|
self.field("editor", lambda: AuthUserFilter).with_public(public)
|
||||||
self.field("created", DateFilter).with_public(public)
|
self.field("created", DateFilter).with_public(public)
|
||||||
self.field("updated", DateFilter).with_public(public)
|
self.field("updated", DateFilter).with_public(public)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
import types
|
||||||
from typing import Generic, Dict, Type, Optional, Union, Any
|
from typing import Generic, Dict, Type, Optional, Union, Any
|
||||||
|
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
|
from cpl.dependency import get_provider
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
from cpl.graphql.typing import AttributeName
|
from cpl.graphql.typing import AttributeName
|
||||||
@@ -39,7 +41,7 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
def get_fields(self) -> dict[str, Field]:
|
def get_fields(self) -> dict[str, Field]:
|
||||||
return self._fields
|
return self._fields
|
||||||
|
|
||||||
def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field:
|
||||||
if isinstance(name, property):
|
if isinstance(name, property):
|
||||||
name = name.fget.__name__
|
name = name.fget.__name__
|
||||||
|
|
||||||
@@ -62,6 +64,9 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
return self.field(name, list[t], optional)
|
return self.field(name, list[t], optional)
|
||||||
|
|
||||||
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||||
|
if not isinstance(t, type) and callable(t):
|
||||||
|
return self.field(name, t, optional)
|
||||||
|
|
||||||
return self.field(name, t().to_strawberry(), optional)
|
return self.field(name, t().to_strawberry(), optional)
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
def to_strawberry(self) -> Type:
|
||||||
@@ -69,20 +74,28 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
if TypeCollector.has(cls):
|
if TypeCollector.has(cls):
|
||||||
return TypeCollector.get(cls)
|
return TypeCollector.get(cls)
|
||||||
|
|
||||||
annotations = {}
|
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
|
||||||
namespace = {}
|
# register early to handle recursive types
|
||||||
|
TypeCollector.set(cls, gql_cls)
|
||||||
|
|
||||||
|
annotations: dict[str, Any] = {}
|
||||||
|
namespace: dict[str, Any] = {}
|
||||||
|
|
||||||
for name, f in self._fields.items():
|
for name, f in self._fields.items():
|
||||||
typ = f.type
|
t = f.type
|
||||||
if isinstance(typ, type) and issubclass(typ, Input):
|
|
||||||
typ = typ().to_strawberry()
|
|
||||||
elif isinstance(typ, Input):
|
|
||||||
typ = typ.to_strawberry()
|
|
||||||
|
|
||||||
ann = typ if not f.optional else Optional[typ]
|
if isinstance(t, types.FunctionType):
|
||||||
|
_t = get_provider().get_service(t())
|
||||||
|
if _t is None:
|
||||||
|
raise ValueError(f"'{t()}' could not be resolved from the provider")
|
||||||
|
t = _t.to_strawberry()
|
||||||
|
elif isinstance(t, type) and issubclass(t, Input):
|
||||||
|
t = t().to_strawberry()
|
||||||
|
elif isinstance(t, Input):
|
||||||
|
t = t.to_strawberry()
|
||||||
|
|
||||||
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
|
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
|
||||||
annotations[py_name] = ann
|
annotations[py_name] = t if not f.optional else Optional[t]
|
||||||
|
|
||||||
field_args = {}
|
field_args = {}
|
||||||
if py_name != name:
|
if py_name != name:
|
||||||
@@ -93,6 +106,10 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
|
|
||||||
namespace["__annotations__"] = annotations
|
namespace["__annotations__"] = annotations
|
||||||
|
|
||||||
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
for k, v in namespace.items():
|
||||||
|
setattr(gql_cls, k, v)
|
||||||
|
|
||||||
|
gql_cls.__annotations__ = annotations
|
||||||
|
gql_type = strawberry.input(gql_cls)
|
||||||
TypeCollector.set(cls, gql_type)
|
TypeCollector.set(cls, gql_type)
|
||||||
return gql_type
|
return gql_type
|
||||||
|
|||||||
Reference in New Issue
Block a user