Added mutations #181

This commit is contained in:
2025-09-28 18:51:28 +02:00
parent 3286a95cbf
commit 39d06dfe48
16 changed files with 424 additions and 210 deletions

View File

@@ -0,0 +1,178 @@
import functools
import inspect
from abc import ABC
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Resolver
from cpl.graphql.utils.type_collector import TypeCollector
class QueryABC(StrawberryProtocol, ABC):
def __init__(self):
ABC.__init__(self)
self._fields: dict[str, Field] = {}
@property
def fields(self) -> dict[str, Field]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
name: str,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
return self.field(name, t().to_strawberry(), resolver)
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
_type = arg.type
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
_type = _type().to_strawberry()
ann = Optional[_type] if arg.optional else _type
if arg.default is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
async def _resolver(*args, **kwargs):
if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user()
if user is None:
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
if f.require_any_permission:
if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
if f.require_any:
perms, resolvers = f.require_any
if not any([await user.has_permission(p) for p in perms]):
ctx = QueryContext([x.name for x in await user.permissions])
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
if iscoroutinefunction(r):
return await r(*args, **kwargs)
return r(*args, **kwargs)
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
annotations[name] = f.type
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema
from cpl.graphql.service.graphql import GraphQLService
from cpl.graphql.service.schema import Schema
class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [Schema, RootQuery]
singleton = [Schema, RootQuery, RootMutation]
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]

View File

@@ -1,11 +1,10 @@
from enum import Enum
from typing import Optional, Any
from typing import Optional
from graphql import GraphQLResolveInfo
from cpl.auth.schema import AuthUser, Permission
from cpl.core.ctx import get_user
from cpl.core.utils import get_value
class QueryContext:

View File

@@ -1,38 +1,54 @@
from typing import Any
from typing import Any, Self
class Argument:
def __init__(
self,
t: type,
name: str,
t: type,
description: str = None,
default_value: Any = None,
default: Any = None,
optional: bool = None,
):
self._type = t
self._name = name
self._type = t
self._description = description
self._default_value = default_value
self._default = default
self._optional = optional
@property
def type(self) -> type:
return self._type
@property
def name(self) -> str:
return self._name
@property
def type(self) -> type:
return self._type
@property
def description(self) -> str | None:
return self._description
@property
def default_value(self) -> Any | None:
return self._default_value
def default(self) -> Any | None:
return self._default
@property
def optional(self) -> bool | None:
return self._optional
def with_description(self, description: str) -> Self:
self._description = description
return self
def with_default(self, default: Any) -> Self:
self._default = default
return self
def with_optional(self, optional: bool) -> Self:
self._optional = optional
return self
def with_required(self, required: bool = True) -> Self:
self._optional = not required
return self

View File

@@ -91,22 +91,26 @@ class Field:
self._optional = optional
return self
def with_required(self, required: bool = True) -> Self:
self._optional = not required
return self
def with_default(self, default) -> Self:
self._default = default
return self
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(arg_type, name, description, default_value, optional)
return self
self._args[name] = Argument(name, arg_type, description, default_value, optional)
return self._args[name]
def with_arguments(self, args: list[Argument]) -> Self:
for arg in args:
if not isinstance(arg, Argument):
raise ValueError(f"Expected Argument instance, got {type(arg)}")
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
return self
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
@@ -126,7 +130,7 @@ class Field:
self._require_any = (permissions, resolvers)
return self
def with_public(self, public: bool = False) -> Self:
def with_public(self, public: bool = True) -> Self:
assert self._require_any is None, "Field cannot be public and have require_any set"
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
self._public = public

View File

@@ -1,4 +1,4 @@
from typing import Generic, Dict, Type, Optional, Self, Union
from typing import Generic, Dict, Type, Optional, Union, Any
import strawberry
@@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
class Input(StrawberryProtocol, Generic[T]):
def __init__(self):
self._fields: Dict[str, Field] = {}
self._values: Dict[str, Any] = {}
@property
def fields(self) -> Dict[str, Field]:
return self._fields
def __getattr__(self, item):
if item in self._values:
return self._values[item]
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
def __setattr__(self, key, value):
if key in {"_fields", "_values"}:
super().__setattr__(key, value)
elif key in self._fields:
self._values[key] = value
else:
super().__setattr__(key, value)
def get(self, key: str, default=None):
return self._values.get(key, default)
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
self._fields[name] = Field(name, typ, optional=optional)
return self._fields[name]
def string_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, str)
def int_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, int, optional)
def float_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, float, optional)
def bool_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, bool, optional)
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
return self.field(name, list[t], optional)
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
return self.field(name, t().to_strawberry(), optional)
def to_strawberry(self) -> Type:
cls = self.__class__

View File

@@ -0,0 +1,25 @@
from typing import Type
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.schema.field import Field
class Mutation(QueryABC):
@inject
def __init__(self, provider: ServiceProvider):
QueryABC.__init__(self)
self._provider = provider
from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema)
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
sub = self._provider.get_service(cls)
if not sub:
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
return self.field(name, sub.to_strawberry(), lambda: sub)

View File

@@ -1,76 +1,32 @@
import functools
import inspect
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional
from typing import Callable, Type
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx import get_user
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.typing import Resolver
from cpl.graphql.utils.type_collector import TypeCollector
class Query(StrawberryProtocol):
class Query(QueryABC):
@inject
def __init__(self, provider: ServiceProvider):
QueryABC.__init__(self)
self._provider = provider
from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema)
self._fields: dict[str, Field] = {}
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
name: str,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
return self.field(name, t().to_strawberry(), resolver)
def with_query(self, name: str, subquery_cls: Type["Query"]):
def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
sub = self._provider.get_service(subquery_cls)
if not sub:
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
self.field(name, sub.to_strawberry(), lambda: sub)
return self
return self.field(name, sub.to_strawberry(), lambda: sub)
def collection_field(
self,
@@ -105,10 +61,10 @@ class Query(StrawberryProtocol):
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
f.with_argument(filter.to_strawberry(), "filter")
f.with_argument(sort.to_strawberry(), "sort")
f.with_argument(int, "skip", default_value=0)
f.with_argument(int, "take", default_value=10)
f.with_argument("filter", filter.to_strawberry())
f.with_argument("sort", sort.to_strawberry())
f.with_argument("skip", int, default_value=0)
f.with_argument("take", int, default_value=10)
return f
def dao_collection_field(
@@ -168,120 +124,8 @@ class Query(StrawberryProtocol):
return Collection(nodes=data, total_count=total_count, count=len(data))
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
f.with_argument(filter.to_strawberry(), "filter")
f.with_argument(sort.to_strawberry(), "sort")
f.with_argument(int, "skip", default_value=0)
f.with_argument(int, "take", default_value=10)
f.with_argument("filter", filter.to_strawberry())
f.with_argument("sort", sort.to_strawberry())
f.with_argument("skip", int, default_value=0)
f.with_argument("take", int, default_value=10)
return f
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
ann = Optional[arg.type] if arg.optional else arg.type
if arg.default_value is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default_value,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
async def _resolver(*args, **kwargs):
if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user()
if user is None:
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
if f.require_any_permission:
if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
if f.require_any:
perms, resolvers = f.require_any
if not any([await user.has_permission(p) for p in perms]):
ctx = QueryContext([x.name for x in await user.permissions])
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
if iscoroutinefunction(r):
return await r(*args, **kwargs)
return r(*args, **kwargs)
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(
f"Error converting field '{f.name}' to strawberry field: {e}"
) from e
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
annotations[name] = f.type
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.mutation import Mutation
class RootMutation(Mutation):
def __init__(self):
Mutation.__init__(self)

View File

@@ -6,6 +6,7 @@ import strawberry
from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
@@ -25,7 +26,17 @@ class Schema:
@property
def query(self) -> RootQuery:
return self._provider.get_service(RootQuery)
query = self._provider.get_service(RootQuery)
if not query:
raise ValueError("RootQuery not registered in service provider")
return query
@property
def mutation(self) -> RootMutation:
mutation = self._provider.get_service(RootMutation)
if not mutation:
raise ValueError("RootMutation not registered in service provider")
return mutation
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
self._types[t.__name__] = t
@@ -43,13 +54,13 @@ class Schema:
def build(self) -> strawberry.Schema:
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
query = self._provider.get_service(RootQuery)
if not query:
raise ValueError("RootQuery not registered in service provider")
query = self.query
mutation = self.mutation
self._schema = strawberry.Schema(
query=query.to_strawberry(),
mutation=None,
query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=None,
types=self._get_types(),
)