Changed to strawberry #181
This commit is contained in:
0
src/cpl-graphql/cpl/graphql/abc/__init__.py
Normal file
0
src/cpl-graphql/cpl/graphql/abc/__init__.py
Normal file
9
src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py
Normal file
9
src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from typing import Protocol, Type, runtime_checkable
|
||||
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class StrawberryProtocol(Protocol):
|
||||
def to_strawberry(self) -> Type: ...
|
||||
def get_fields(self) -> dict[str, Field]: ...
|
||||
@@ -1,17 +1,15 @@
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.schema.collection import CollectionGraphType
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.service.schema import Schema
|
||||
from cpl.graphql.service.service import GraphQLService
|
||||
from cpl.graphql.service.type_converter import TypeConverter
|
||||
|
||||
|
||||
class GraphQLModule(Module):
|
||||
dependencies = [ApiModule]
|
||||
singleton = [TypeConverter, Schema]
|
||||
scoped = [GraphQLService, RootQuery, CollectionGraphType]
|
||||
singleton = [Schema, RootQuery]
|
||||
scoped = [GraphQLService]
|
||||
|
||||
@staticmethod
|
||||
def configure(services: ServiceProvider) -> None:
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Argument:
|
||||
def __init__(self, t: type, name: str, description: str = None, default_value=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
t: type,
|
||||
name: str,
|
||||
description: str = None,
|
||||
default_value: Any = None,
|
||||
optional: bool = None,
|
||||
):
|
||||
self._type = t
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._default_value = default_value
|
||||
self._optional = optional
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
@@ -18,5 +30,9 @@ class Argument:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def default_value(self):
|
||||
def default_value(self) -> Any | None:
|
||||
return self._default_value
|
||||
|
||||
@property
|
||||
def optional(self) -> bool | None:
|
||||
return self._optional
|
||||
|
||||
@@ -1,18 +1,53 @@
|
||||
from typing import Generic, Type
|
||||
from typing import Type, Dict, List
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
|
||||
|
||||
class Collection(Generic[T]):
|
||||
class CollectionGraphTypeFactory:
|
||||
_cache: Dict[Type, Type] = {}
|
||||
|
||||
@classmethod
|
||||
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
|
||||
if node_type in cls._cache:
|
||||
return cls._cache[node_type]
|
||||
|
||||
gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||
|
||||
gql_type = strawberry.type(
|
||||
type(
|
||||
f"{node_type.__name__}Collection",
|
||||
(),
|
||||
{
|
||||
"__annotations__": {
|
||||
"nodes": List[gql_node],
|
||||
"total_count": int,
|
||||
"count": int,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
cls._cache[node_type] = gql_type
|
||||
return gql_type
|
||||
|
||||
|
||||
class Collection:
|
||||
def __init__(self, nodes: list[T], total_count: int, count: int):
|
||||
self.nodes = nodes
|
||||
self.totalCount = total_count
|
||||
self.count = count
|
||||
self._nodes = nodes
|
||||
self._total_count = total_count
|
||||
self._count = count
|
||||
|
||||
class CollectionGraphType(GraphType[T]):
|
||||
def __init__(self, t: Type[GraphType[T]]):
|
||||
GraphType.__init__(self)
|
||||
self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount)
|
||||
self.string_field("count", resolver=lambda obj, *_: obj.count)
|
||||
self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes)
|
||||
@property
|
||||
def nodes(self) -> list[T]:
|
||||
return self._nodes
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
return self._total_count
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return self._count
|
||||
|
||||
@@ -6,11 +6,24 @@ from cpl.graphql.typing import TQuery, Resolver
|
||||
|
||||
class Field:
|
||||
|
||||
def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
gql_type: type = None,
|
||||
resolver: Resolver = None,
|
||||
optional=None,
|
||||
default=None,
|
||||
subquery: TQuery = None,
|
||||
parent_type=None,
|
||||
):
|
||||
self._name = name
|
||||
self._gql_type = gql_type
|
||||
self._resolver = resolver
|
||||
self._optional = optional or True
|
||||
self._default = default
|
||||
|
||||
self._subquery = subquery
|
||||
self._parent_type = parent_type
|
||||
|
||||
self._args: dict[str, Argument] = {}
|
||||
|
||||
@@ -26,6 +39,14 @@ class Field:
|
||||
def resolver(self) -> callable:
|
||||
return self._resolver
|
||||
|
||||
@property
|
||||
def optional(self) -> bool | None:
|
||||
return self._optional
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
return self._args
|
||||
@@ -34,10 +55,18 @@ class Field:
|
||||
def subquery(self) -> TQuery | None:
|
||||
return self._subquery
|
||||
|
||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self:
|
||||
@property
|
||||
def parent_type(self):
|
||||
return self._parent_type
|
||||
|
||||
@property
|
||||
def arguments(self) -> dict[str, Argument]:
|
||||
return self._args
|
||||
|
||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
||||
if name in self._args:
|
||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||
self._args[name] = Argument(arg_type, name, description, default_value)
|
||||
self._args[name] = Argument(arg_type, name, description, default_value, optional)
|
||||
return self
|
||||
|
||||
def with_arguments(self, args: list[Argument]) -> Self:
|
||||
@@ -45,5 +74,5 @@ class Field:
|
||||
if not isinstance(arg, Argument):
|
||||
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
||||
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value)
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
||||
return self
|
||||
|
||||
@@ -3,7 +3,5 @@ from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class Filter(Input[T]):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
|
||||
@@ -4,7 +4,7 @@ from cpl.core.typing import T
|
||||
from cpl.graphql.schema.query import Query
|
||||
|
||||
|
||||
class GraphType(Generic[T], Query):
|
||||
class GraphType(Query, Generic[T]):
|
||||
|
||||
def __init__(self):
|
||||
Query.__init__(self)
|
||||
@@ -1,26 +1,34 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Type, Generic
|
||||
from typing import Generic, Dict, Type, Any, Optional
|
||||
|
||||
import graphene
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
|
||||
class Input(Generic[T], graphene.InputObjectType):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
graphene.InputObjectType.__init__(self)
|
||||
self._fields: dict[str, Field] = {}
|
||||
class Input(StrawberryProtocol, Generic[T]):
|
||||
def __init__(self):
|
||||
self._fields: Dict[str, Field] = {}
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(
|
||||
self,
|
||||
field: str,
|
||||
t: Type["Input"] | Type[int | str | bool | datetime | list | Enum],
|
||||
):
|
||||
self._fields[field] = Field(field, t)
|
||||
def field(self, name: str, typ: type, optional: bool = True):
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
annotations = {}
|
||||
namespace = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
ann = f.type if not f.optional else Optional[f.type]
|
||||
annotations[name] = ann
|
||||
|
||||
if f.optional:
|
||||
namespace[name] = None
|
||||
elif f.default is not None:
|
||||
namespace[name] = f.default
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace))
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.query import Query
|
||||
|
||||
|
||||
class ObjectGraphType(GraphType[T], Query):
|
||||
|
||||
def __init__(self):
|
||||
Query.__init__(self)
|
||||
@@ -1,21 +1,27 @@
|
||||
from typing import Callable, Type
|
||||
import inspect
|
||||
from typing import Callable, Type, Any, Optional
|
||||
|
||||
from graphene import ObjectType
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.graphql.schema.argument import Argument
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from cpl.graphql.typing import Resolver
|
||||
|
||||
|
||||
class Query(ObjectType):
|
||||
class Query(StrawberryProtocol):
|
||||
|
||||
def __init__(self):
|
||||
from cpl.graphql.schema.field import Field
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
self._provider = provider
|
||||
|
||||
ObjectType.__init__(self)
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
self._fields: dict[str, Field] = {}
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
@@ -25,69 +31,137 @@ class Query(ObjectType):
|
||||
self,
|
||||
name: str,
|
||||
t: type,
|
||||
resolver: Callable | None = None,
|
||||
) -> "Field":
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def with_query(self, name: str, subquery: Type["Query"]):
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery)
|
||||
self._fields[name] = f
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> "Field":
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> "Field":
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> "Field":
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> "Field":
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field":
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
||||
sub = self._provider.get_service(subquery_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
||||
|
||||
self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
return self
|
||||
|
||||
def collection_field(
|
||||
self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None
|
||||
) -> "Field":
|
||||
from cpl.graphql.schema.collection import Collection, CollectionGraphType
|
||||
self,
|
||||
t: type,
|
||||
name: str,
|
||||
filter_type: Type[StrawberryProtocol],
|
||||
sort_type: Type[StrawberryProtocol],
|
||||
resolver: Callable,
|
||||
) -> Field:
|
||||
# self._schema.with_type(filter_type)
|
||||
# self._schema.with_type(sort_type)
|
||||
|
||||
def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int):
|
||||
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
|
||||
items = resolver()
|
||||
if filter:
|
||||
for field, value in filter.__dict__.items():
|
||||
if value is None:
|
||||
continue
|
||||
items = [i for i in items if getattr(i, field) == value]
|
||||
|
||||
for field in filter or []:
|
||||
if filter[field] is None:
|
||||
continue
|
||||
|
||||
items = [item for item in items if getattr(item, field) == filter[field]]
|
||||
|
||||
for field in sort or []:
|
||||
if sort[field] is None:
|
||||
continue
|
||||
|
||||
reverse = sort[field] == SortOrder.DESC
|
||||
items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse)
|
||||
|
||||
if sort:
|
||||
for field, direction in sort.__dict__.items():
|
||||
reverse = direction == SortOrder.DESC
|
||||
items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse)
|
||||
total_count = len(items)
|
||||
paged = items[skip : skip + take]
|
||||
return Collection(nodes=paged, total_count=total_count, count=len(paged))
|
||||
|
||||
# base = getattr(t, "__gqlname__", t.__class__.__name__)
|
||||
wrapper = CollectionGraphType(t)
|
||||
# wrapper.set_graphql_name(f"{base}Collection")
|
||||
f = self.field(name, wrapper, resolver=_resolve_collection)
|
||||
return f.with_arguments(
|
||||
[
|
||||
Argument(filter_type, "filter"),
|
||||
Argument(sort_type, "sort"),
|
||||
Argument(int, "skip", default_value=0),
|
||||
Argument(int, "take", default_value=10),
|
||||
]
|
||||
)
|
||||
filter = self._provider.get_service(filter_type)
|
||||
if not filter:
|
||||
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
|
||||
|
||||
sort = self._provider.get_service(sort_type)
|
||||
if not sort:
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
||||
f.with_argument(filter.to_strawberry(), "filter")
|
||||
f.with_argument(sort.to_strawberry(), "sort")
|
||||
f.with_argument(int, "skip", default_value=0)
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
return f
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
for arg in f.arguments.values():
|
||||
ann = Optional[arg.type] if arg.optional else arg.type
|
||||
|
||||
if arg.default_value is None:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
)
|
||||
else:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
default=arg.default_value,
|
||||
)
|
||||
|
||||
params.append(param)
|
||||
|
||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||
|
||||
def _resolver(*args, **kwargs):
|
||||
return f.resolver(*args, **kwargs) if f.resolver else None
|
||||
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
try:
|
||||
if f.resolver:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
return strawberry.field(resolver=resolver)
|
||||
|
||||
if not f.resolver:
|
||||
return strawberry.field(resolver=lambda *_, **__: None)
|
||||
|
||||
return strawberry.field(resolver=f.resolver)
|
||||
except StrawberryException as e:
|
||||
raise Exception(
|
||||
f"Error converting field '{f.name}' to strawberry field: {e}"
|
||||
) from e
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
annotations[name] = f.type
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace))
|
||||
|
||||
@@ -1,43 +1,54 @@
|
||||
import graphene
|
||||
from typing import Type, Self
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.schema.collection import CollectionGraphType
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.service.type_converter import TypeConverter
|
||||
|
||||
|
||||
class Schema:
|
||||
|
||||
def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider):
|
||||
def __init__(self, logger: APILogger, provider: ServiceProvider):
|
||||
self._logger = logger
|
||||
self._provider = provider
|
||||
self._converter = converter
|
||||
|
||||
self._types = set(GraphType.__subclasses__())
|
||||
self._types.remove(CollectionGraphType)
|
||||
self._types: dict[str, Type[StrawberryProtocol]] = {}
|
||||
|
||||
self._query = query
|
||||
self._schema = None
|
||||
|
||||
@property
|
||||
def schema(self) -> graphene.Schema | None:
|
||||
def schema(self) -> strawberry.Schema | None:
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def query(self) -> RootQuery:
|
||||
return self._query
|
||||
return self._provider.get_service(RootQuery)
|
||||
|
||||
def with_type(self, t: type[GraphType]):
|
||||
self._types.add(t)
|
||||
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||
self._types[t.__name__] = t
|
||||
return self
|
||||
|
||||
def build(self) -> graphene.Schema:
|
||||
self._schema = graphene.Schema(
|
||||
query=self._converter.to_graphene(self._query),
|
||||
def _get_types(self):
|
||||
types: list[Type] = []
|
||||
for t in self._types.values():
|
||||
t_obj = self._provider.get_service(t)
|
||||
if not t_obj:
|
||||
raise ValueError(f"Type '{t.__name__}' not registered in service provider")
|
||||
types.append(t_obj.to_strawberry())
|
||||
|
||||
return types
|
||||
|
||||
def build(self) -> strawberry.Schema:
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
|
||||
self._schema = strawberry.Schema(
|
||||
query=query.to_strawberry(),
|
||||
mutation=None,
|
||||
subscription=None,
|
||||
# types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None,
|
||||
types=self._get_types(),
|
||||
)
|
||||
return self._schema
|
||||
|
||||
@@ -16,7 +16,7 @@ class GraphQLService:
|
||||
variables: Optional[Dict[str, Any]],
|
||||
request: TRequest,
|
||||
) -> Dict[str, Any]:
|
||||
result = await self._schema.execute_async(
|
||||
result = await self._schema.execute(
|
||||
query,
|
||||
variable_values=variables,
|
||||
context_value={"request": request},
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from inspect import isclass
|
||||
|
||||
import graphene
|
||||
from typing import Any, get_origin, get_args
|
||||
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.graphql.schema.argument import Argument
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.object_graph_type import ObjectGraphType
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.utils.name_pipe import NamePipe
|
||||
|
||||
|
||||
class TypeConverter:
|
||||
__scalar_map: dict[Any, type[graphene.Scalar]] = {
|
||||
str: graphene.String,
|
||||
int: graphene.Int,
|
||||
float: graphene.Float,
|
||||
bool: graphene.Boolean,
|
||||
}
|
||||
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
self._provider = provider
|
||||
|
||||
def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field:
|
||||
arguments = {}
|
||||
if args is not None:
|
||||
arguments = {
|
||||
arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value)
|
||||
for arg in args.values()
|
||||
}
|
||||
|
||||
return graphene.Field(t, args=arguments, resolver=resolver)
|
||||
|
||||
def to_graphene(self, t: Any, name: str | None = None) -> Any:
|
||||
try:
|
||||
origin = get_origin(t)
|
||||
args = get_args(t)
|
||||
|
||||
if t in self.__scalar_map:
|
||||
return self.__scalar_map[t]
|
||||
|
||||
if origin in (list, typing.List):
|
||||
if not args:
|
||||
raise ValueError("List must specify element type, e.g. list[str]")
|
||||
inner = self.to_graphene(args[0])
|
||||
return graphene.List(inner)
|
||||
|
||||
if t is list or t is typing.List:
|
||||
raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]")
|
||||
|
||||
if isclass(t) and issubclass(t, Enum):
|
||||
return graphene.Enum.from_enum(t)
|
||||
|
||||
from cpl.graphql.schema.query import Query
|
||||
if isinstance(t, type) and issubclass(t, (Query)):
|
||||
query = self._provider.get_service(t)
|
||||
if query is None:
|
||||
raise ValueError(f"Could not resolve query of type {t}")
|
||||
|
||||
t = query
|
||||
|
||||
if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)):
|
||||
t = t()
|
||||
|
||||
if isinstance(t, (Query, Filter, Sort)):
|
||||
attrs = {}
|
||||
for field in t.get_fields().values():
|
||||
if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None:
|
||||
subquery = self._provider.get_service(field.subquery)
|
||||
sub = self.to_graphene(subquery, name=field.name.capitalize())
|
||||
attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver)
|
||||
continue
|
||||
|
||||
attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver)
|
||||
|
||||
class_name = NamePipe.to_str(name or t.__class__)
|
||||
if isinstance(t, (Filter, Sort)):
|
||||
return type(class_name, (graphene.InputObjectType,), attrs)
|
||||
|
||||
return type(class_name, (graphene.ObjectType,), attrs)
|
||||
|
||||
raise ValueError(f"Unsupported field type: {t}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e
|
||||
@@ -1,2 +1,2 @@
|
||||
cpl-api
|
||||
graphene==3.4.3
|
||||
strawberry-graphql==0.282.0
|
||||
Reference in New Issue
Block a user