Added mutations #181
This commit is contained in:
178
src/cpl-graphql/cpl/graphql/abc/query_abc.py
Normal file
178
src/cpl-graphql/cpl/graphql/abc/query_abc.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import functools
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Callable, Type, Any, Optional
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class QueryABC(StrawberryProtocol, ABC):
|
||||
|
||||
def __init__(self):
|
||||
ABC.__init__(self)
|
||||
self._fields: dict[str, Field] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
@property
|
||||
def fields_count(self) -> int:
|
||||
return len(self._fields)
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(
|
||||
self,
|
||||
name: str,
|
||||
t: type,
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
for arg in f.arguments.values():
|
||||
_type = arg.type
|
||||
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
|
||||
_type = _type().to_strawberry()
|
||||
|
||||
ann = Optional[_type] if arg.optional else _type
|
||||
|
||||
if arg.default is None:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
)
|
||||
else:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
default=arg.default,
|
||||
)
|
||||
|
||||
params.append(param)
|
||||
|
||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||
|
||||
async def _resolver(*args, **kwargs):
|
||||
if f.resolver is None:
|
||||
return None
|
||||
|
||||
if iscoroutinefunction(f.resolver):
|
||||
return await f.resolver(*args, **kwargs)
|
||||
return f.resolver(*args, **kwargs)
|
||||
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||
sig = getattr(resolver, "__signature__", None)
|
||||
|
||||
@functools.wraps(resolver)
|
||||
async def _auth_resolver(*args, **kwargs):
|
||||
if f.public:
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
user = get_user()
|
||||
|
||||
if user is None:
|
||||
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
||||
|
||||
if f.require_any_permission:
|
||||
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
if f.require_any:
|
||||
perms, resolvers = f.require_any
|
||||
if not any([await user.has_permission(p) for p in perms]):
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
||||
|
||||
if not any(resolved):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
if sig:
|
||||
_auth_resolver.__signature__ = sig
|
||||
|
||||
return _auth_resolver
|
||||
|
||||
@staticmethod
|
||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||
if iscoroutinefunction(r):
|
||||
return await r(*args, **kwargs)
|
||||
return r(*args, **kwargs)
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
try:
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda *_, **__: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
resolver = f.resolver
|
||||
|
||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||
except StrawberryException as e:
|
||||
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
annotations[name] = f.type
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
@@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.service.schema import Schema
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLModule(Module):
|
||||
dependencies = [ApiModule]
|
||||
singleton = [Schema, RootQuery]
|
||||
singleton = [Schema, RootQuery, RootMutation]
|
||||
scoped = [GraphQLService]
|
||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
|
||||
from graphql import GraphQLResolveInfo
|
||||
|
||||
from cpl.auth.schema import AuthUser, Permission
|
||||
from cpl.core.ctx import get_user
|
||||
from cpl.core.utils import get_value
|
||||
|
||||
|
||||
class QueryContext:
|
||||
|
||||
@@ -1,38 +1,54 @@
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
|
||||
|
||||
class Argument:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
t: type,
|
||||
name: str,
|
||||
t: type,
|
||||
description: str = None,
|
||||
default_value: Any = None,
|
||||
default: Any = None,
|
||||
optional: bool = None,
|
||||
):
|
||||
self._type = t
|
||||
self._name = name
|
||||
self._type = t
|
||||
self._description = description
|
||||
self._default_value = default_value
|
||||
self._default = default
|
||||
self._optional = optional
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def default_value(self) -> Any | None:
|
||||
return self._default_value
|
||||
def default(self) -> Any | None:
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def optional(self) -> bool | None:
|
||||
return self._optional
|
||||
|
||||
def with_description(self, description: str) -> Self:
|
||||
self._description = description
|
||||
return self
|
||||
|
||||
def with_default(self, default: Any) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_optional(self, optional: bool) -> Self:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_required(self, required: bool = True) -> Self:
|
||||
self._optional = not required
|
||||
return self
|
||||
|
||||
@@ -91,22 +91,26 @@ class Field:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_required(self, required: bool = True) -> Self:
|
||||
self._optional = not required
|
||||
return self
|
||||
|
||||
def with_default(self, default) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
||||
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
|
||||
if name in self._args:
|
||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||
self._args[name] = Argument(arg_type, name, description, default_value, optional)
|
||||
return self
|
||||
self._args[name] = Argument(name, arg_type, description, default_value, optional)
|
||||
return self._args[name]
|
||||
|
||||
def with_arguments(self, args: list[Argument]) -> Self:
|
||||
for arg in args:
|
||||
if not isinstance(arg, Argument):
|
||||
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
||||
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
|
||||
return self
|
||||
|
||||
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
|
||||
@@ -126,7 +130,7 @@ class Field:
|
||||
self._require_any = (permissions, resolvers)
|
||||
return self
|
||||
|
||||
def with_public(self, public: bool = False) -> Self:
|
||||
def with_public(self, public: bool = True) -> Self:
|
||||
assert self._require_any is None, "Field cannot be public and have require_any set"
|
||||
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
|
||||
self._public = public
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Generic, Dict, Type, Optional, Self, Union
|
||||
from typing import Generic, Dict, Type, Optional, Union, Any
|
||||
|
||||
import strawberry
|
||||
|
||||
@@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||
class Input(StrawberryProtocol, Generic[T]):
|
||||
def __init__(self):
|
||||
self._fields: Dict[str, Field] = {}
|
||||
self._values: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> Dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self._values:
|
||||
return self._values[item]
|
||||
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in {"_fields", "_values"}:
|
||||
super().__setattr__(key, value)
|
||||
elif key in self._fields:
|
||||
self._values[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, str)
|
||||
|
||||
def int_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, int, optional)
|
||||
|
||||
def float_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, float, optional)
|
||||
|
||||
def bool_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, bool, optional)
|
||||
|
||||
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
|
||||
return self.field(name, list[t], optional)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||
return self.field(name, t().to_strawberry(), optional)
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
|
||||
25
src/cpl-graphql/cpl/graphql/schema/mutation.py
Normal file
25
src/cpl-graphql/cpl/graphql/schema/mutation.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Type
|
||||
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
|
||||
class Mutation(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
QueryABC.__init__(self)
|
||||
self._provider = provider
|
||||
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
|
||||
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
|
||||
sub = self._provider.get_service(cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
|
||||
|
||||
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
@@ -1,76 +1,32 @@
|
||||
import functools
|
||||
import inspect
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Callable, Type, Any, Optional
|
||||
from typing import Callable, Type
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx import get_user
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class Query(StrawberryProtocol):
|
||||
class Query(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
QueryABC.__init__(self)
|
||||
self._provider = provider
|
||||
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
self._fields: dict[str, Field] = {}
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(
|
||||
self,
|
||||
name: str,
|
||||
t: type,
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
|
||||
sub = self._provider.get_service(subquery_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
||||
|
||||
self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
return self
|
||||
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
|
||||
def collection_field(
|
||||
self,
|
||||
@@ -105,10 +61,10 @@ class Query(StrawberryProtocol):
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
||||
f.with_argument(filter.to_strawberry(), "filter")
|
||||
f.with_argument(sort.to_strawberry(), "sort")
|
||||
f.with_argument(int, "skip", default_value=0)
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
f.with_argument("filter", filter.to_strawberry())
|
||||
f.with_argument("sort", sort.to_strawberry())
|
||||
f.with_argument("skip", int, default_value=0)
|
||||
f.with_argument("take", int, default_value=10)
|
||||
return f
|
||||
|
||||
def dao_collection_field(
|
||||
@@ -168,120 +124,8 @@ class Query(StrawberryProtocol):
|
||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||
f.with_argument(filter.to_strawberry(), "filter")
|
||||
f.with_argument(sort.to_strawberry(), "sort")
|
||||
f.with_argument(int, "skip", default_value=0)
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
f.with_argument("filter", filter.to_strawberry())
|
||||
f.with_argument("sort", sort.to_strawberry())
|
||||
f.with_argument("skip", int, default_value=0)
|
||||
f.with_argument("take", int, default_value=10)
|
||||
return f
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
for arg in f.arguments.values():
|
||||
ann = Optional[arg.type] if arg.optional else arg.type
|
||||
|
||||
if arg.default_value is None:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
)
|
||||
else:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
default=arg.default_value,
|
||||
)
|
||||
|
||||
params.append(param)
|
||||
|
||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||
|
||||
async def _resolver(*args, **kwargs):
|
||||
if f.resolver is None:
|
||||
return None
|
||||
|
||||
if iscoroutinefunction(f.resolver):
|
||||
return await f.resolver(*args, **kwargs)
|
||||
return f.resolver(*args, **kwargs)
|
||||
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||
sig = getattr(resolver, "__signature__", None)
|
||||
|
||||
@functools.wraps(resolver)
|
||||
async def _auth_resolver(*args, **kwargs):
|
||||
if f.public:
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
user = get_user()
|
||||
|
||||
if user is None:
|
||||
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
||||
|
||||
if f.require_any_permission:
|
||||
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
if f.require_any:
|
||||
perms, resolvers = f.require_any
|
||||
if not any([await user.has_permission(p) for p in perms]):
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
||||
|
||||
if not any(resolved):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
if sig:
|
||||
_auth_resolver.__signature__ = sig
|
||||
|
||||
return _auth_resolver
|
||||
|
||||
@staticmethod
|
||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||
if iscoroutinefunction(r):
|
||||
return await r(*args, **kwargs)
|
||||
return r(*args, **kwargs)
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
try:
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda *_, **__: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
resolver = f.resolver
|
||||
|
||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||
except StrawberryException as e:
|
||||
raise Exception(
|
||||
f"Error converting field '{f.name}' to strawberry field: {e}"
|
||||
) from e
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
annotations[name] = f.type
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
6
src/cpl-graphql/cpl/graphql/schema/root_mutation.py
Normal file
6
src/cpl-graphql/cpl/graphql/schema/root_mutation.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
|
||||
class RootMutation(Mutation):
|
||||
def __init__(self):
|
||||
Mutation.__init__(self)
|
||||
@@ -6,6 +6,7 @@ import strawberry
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
|
||||
|
||||
@@ -25,7 +26,17 @@ class Schema:
|
||||
|
||||
@property
|
||||
def query(self) -> RootQuery:
|
||||
return self._provider.get_service(RootQuery)
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
return query
|
||||
|
||||
@property
|
||||
def mutation(self) -> RootMutation:
|
||||
mutation = self._provider.get_service(RootMutation)
|
||||
if not mutation:
|
||||
raise ValueError("RootMutation not registered in service provider")
|
||||
return mutation
|
||||
|
||||
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||
self._types[t.__name__] = t
|
||||
@@ -43,13 +54,13 @@ class Schema:
|
||||
|
||||
def build(self) -> strawberry.Schema:
|
||||
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
|
||||
query = self.query
|
||||
mutation = self.mutation
|
||||
|
||||
self._schema = strawberry.Schema(
|
||||
query=query.to_strawberry(),
|
||||
mutation=None,
|
||||
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||
subscription=None,
|
||||
types=self._get_types(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user