From 71199f9b9a709a9d071ee3ff47491893988e2e23 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 18:51:28 +0200 Subject: [PATCH] Added mutations #181 --- example/api/src/main.py | 13 +- example/api/src/model/author_dao.py | 4 +- example/api/src/model/post.py | 8 + example/api/src/model/post_dao.py | 2 +- example/api/src/model/post_query.py | 85 +++++++- example/api/src/queries/hello.py | 2 +- src/cpl-graphql/cpl/graphql/abc/query_abc.py | 178 +++++++++++++++++ src/cpl-graphql/cpl/graphql/graphql_module.py | 5 +- src/cpl-graphql/cpl/graphql/query_context.py | 3 +- .../cpl/graphql/schema/argument.py | 38 ++-- src/cpl-graphql/cpl/graphql/schema/field.py | 14 +- src/cpl-graphql/cpl/graphql/schema/input.py | 44 ++++- .../cpl/graphql/schema/mutation.py | 25 +++ src/cpl-graphql/cpl/graphql/schema/query.py | 184 ++---------------- .../cpl/graphql/schema/root_mutation.py | 6 + src/cpl-graphql/cpl/graphql/service/schema.py | 23 ++- 16 files changed, 424 insertions(+), 210 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/abc/query_abc.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/mutation.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/root_mutation.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 06e39aa4..fdd7dff0 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -5,7 +5,6 @@ from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, A from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder -from cpl.auth.permission.permissions import Permissions from cpl.auth.schema import AuthUser, Role from cpl.core.configuration import Configuration from cpl.core.console import Console @@ -17,7 +16,7 @@ from cpl.graphql.graphql_module import GraphQLModule from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao -from model.post_query import PostFilter, PostSort, PostGraphType +from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation from permissions import PostPermissions from queries.hello import HelloQuery from scoped_service import ScopedService @@ -64,6 +63,7 @@ def main(): .add_transient(PostGraphType) .add_transient(PostFilter) .add_transient(PostSort) + .add_transient(PostMutation) ) app = builder.build() @@ -77,8 +77,8 @@ def main(): path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", - authentication=True, - permissions=[Permissions.administrator], + # authentication=True, + # permissions=[Permissions.administrator], ) app.with_routes_directory("routes") @@ -88,9 +88,12 @@ def main(): schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) ( schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) - .with_require_any_permission(PostPermissions.read) + # .with_require_any_permission(PostPermissions.read) + .with_public() ) + schema.mutation.with_mutation("post", PostMutation).with_public() + app.with_playground() app.with_graphiql() diff --git a/example/api/src/model/author_dao.py b/example/api/src/model/author_dao.py index 98b997a6..d1b1afc0 100644 --- a/example/api/src/model/author_dao.py +++ b/example/api/src/model/author_dao.py @@ -7,5 +7,5 @@ class AuthorDao(DbModelDaoABC): def __init__(self): DbModelDaoABC.__init__(self, Author, "authors") - self.attribute(Author.first_name, str) - self.attribute(Author.last_name, str) \ No newline at end of file + self.attribute(Author.first_name, str, db_name="firstname") + self.attribute(Author.last_name, str, db_name="lastname") \ No newline at end of file diff --git a/example/api/src/model/post.py b/example/api/src/model/post.py index d5801cd0..15b670b8 100644 --- a/example/api/src/model/post.py +++ b/example/api/src/model/post.py @@ -31,6 +31,14 @@ class Post(DbModelABC[Self]): def title(self) -> str: return self._title + @title.setter + def title(self, value: str): + self._title = value + @property def content(self) -> str: return self._content + + @content.setter + def content(self, value: str): + self._content = value diff --git a/example/api/src/model/post_dao.py b/example/api/src/model/post_dao.py index be8e5668..3205f8de 100644 --- a/example/api/src/model/post_dao.py +++ b/example/api/src/model/post_dao.py @@ -3,7 +3,7 @@ from model.author_dao import AuthorDao from model.post import Post -class PostDao(DbModelDaoABC): +class PostDao(DbModelDaoABC[Post]): def __init__(self, authors: AuthorDao): DbModelDaoABC.__init__(self, Post, "posts") diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 381c94ca..6334c51e 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,11 +1,15 @@ from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.input import Input +from cpl.graphql.schema.mutation import Mutation from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter from model.post import Post +from model.post_dao import PostDao + class PostFilter(Filter[Post]): def __init__(self): @@ -38,9 +42,7 @@ class PostGraphType(GraphType[Post]): def r_name(ctx: QueryContext): return ctx.user.username == "admin" - self.object_field("author", AuthorGraphType, resolver=_a).with_require_any( - [], [r_name] - ) + self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name])) self.string_field( "title", resolver=lambda root: root.title, @@ -49,3 +51,80 @@ class PostGraphType(GraphType[Post]): "content", resolver=lambda root: root.content, ).with_public(True) + + +class PostCreateInput(Input[Post]): + title: str + content: str + author_id: int + + def __init__(self): + Input.__init__(self) + self.string_field("title").with_required() + self.string_field("content").with_required() + self.int_field("author_id").with_required() + +class PostUpdateInput(Input[Post]): + title: str + content: str + author_id: int + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("title").with_required(False) + self.string_field("content").with_required(False) + +class PostMutation(Mutation): + + def __init__(self, posts: PostDao, authors: AuthorDao): + Mutation.__init__(self) + + self._posts = posts + self._authors = authors + + self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument( + "input", + PostCreateInput, + ).with_required() + self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument( + "input", + PostUpdateInput, + ).with_required() + self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument( + "id", + int, + ).with_required() + self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument( + "id", + int, + ).with_required() + + async def create_post(self, input: PostCreateInput) -> int: + return await self._posts.create(Post(0, input.author_id, input.title, input.content)) + + async def update_post(self, input: PostUpdateInput) -> bool: + post = await self._posts.get_by_id(input.id) + if post is None: + return False + + post.title = input.title if input.title is not None else post.title + post.content = input.content if input.content is not None else post.content + + await self._posts.update(post) + return True + + async def delete_post(self, id: int) -> bool: + post = await self._posts.get_by_id(id) + if post is None: + return False + await self._posts.delete(post) + return True + + async def restore_post(self, id: int) -> bool: + post = await self._posts.get_by_id(id) + if post is None: + return False + await self._posts.restore(post) + return True + diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index addd9173..864e39ab 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -44,7 +44,7 @@ class HelloQuery(Query): self.string_field( "message", resolver=lambda name: f"Hello {name} {get_request().state.request_id}", - ).with_argument(str, "name", "Name to greet", "world") + ).with_argument("name", str, "Name to greet", "world") self.collection_field( UserGraphType, diff --git a/src/cpl-graphql/cpl/graphql/abc/query_abc.py b/src/cpl-graphql/cpl/graphql/abc/query_abc.py new file mode 100644 index 00000000..5023ebea --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/query_abc.py @@ -0,0 +1,178 @@ +import functools +import inspect +from abc import ABC +from asyncio import iscoroutinefunction +from typing import Callable, Type, Any, Optional + +import strawberry +from strawberry.exceptions import StrawberryException + +from cpl.api import Unauthorized, Forbidden +from cpl.core.ctx.user_context import get_user +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.error import graphql_error +from cpl.graphql.query_context import QueryContext +from cpl.graphql.schema.field import Field +from cpl.graphql.typing import Resolver +from cpl.graphql.utils.type_collector import TypeCollector + + +class QueryABC(StrawberryProtocol, ABC): + + def __init__(self): + ABC.__init__(self) + self._fields: dict[str, Field] = {} + + @property + def fields(self) -> dict[str, Field]: + return self._fields + + @property + def fields_count(self) -> int: + return len(self._fields) + + def get_fields(self) -> dict[str, Field]: + return self._fields + + def field( + self, + name: str, + t: type, + resolver: Resolver = None, + ) -> Field: + from cpl.graphql.schema.field import Field + + self._fields[name] = Field(name, t, resolver) + return self._fields[name] + + def string_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, str, resolver) + + def int_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, int, resolver) + + def float_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, float, resolver) + + def bool_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, bool, resolver) + + def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: + return self.field(name, list[t], resolver) + + def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: + return self.field(name, t().to_strawberry(), resolver) + + @staticmethod + def _build_resolver(f: "Field"): + params: list[inspect.Parameter] = [] + for arg in f.arguments.values(): + _type = arg.type + if isinstance(_type, type) and issubclass(_type, StrawberryProtocol): + _type = _type().to_strawberry() + + ann = Optional[_type] if arg.optional else _type + + if arg.default is None: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + ) + else: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + default=arg.default, + ) + + params.append(param) + + sig = inspect.Signature(parameters=params, return_annotation=f.type) + + async def _resolver(*args, **kwargs): + if f.resolver is None: + return None + + if iscoroutinefunction(f.resolver): + return await f.resolver(*args, **kwargs) + return f.resolver(*args, **kwargs) + + _resolver.__signature__ = sig + return _resolver + + def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: + sig = getattr(resolver, "__signature__", None) + + @functools.wraps(resolver) + async def _auth_resolver(*args, **kwargs): + if f.public: + return await self._run_resolver(resolver, *args, **kwargs) + + user = get_user() + + if user is None: + raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) + + if f.require_any_permission: + if not any([await user.has_permission(p) for p in f.require_any_permission]): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + if f.require_any: + perms, resolvers = f.require_any + if not any([await user.has_permission(p) for p in perms]): + ctx = QueryContext([x.name for x in await user.permissions]) + resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] + + if not any(resolved): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + return await self._run_resolver(resolver, *args, **kwargs) + + if sig: + _auth_resolver.__signature__ = sig + + return _auth_resolver + + @staticmethod + async def _run_resolver(r: Callable, *args, **kwargs): + if iscoroutinefunction(r): + return await r(*args, **kwargs) + return r(*args, **kwargs) + + def _field_to_strawberry(self, f: Field) -> Any: + resolver = None + try: + if f.arguments: + resolver = self._build_resolver(f) + elif not f.resolver: + resolver = lambda *_, **__: None + else: + ann = getattr(f.resolver, "__annotations__", {}) + if "return" not in ann or ann["return"] is None: + ann = dict(ann) + ann["return"] = f.type + f.resolver.__annotations__ = ann + resolver = f.resolver + + return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) + except StrawberryException as e: + raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e + + def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + annotations[name] = f.type + namespace[name] = self._field_to_strawberry(f) + + namespace["__annotations__"] = annotations + gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index d9d66aee..b749d16e 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.string_filter import StringFilter +from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery -from cpl.graphql.service.schema import Schema from cpl.graphql.service.graphql import GraphQLService +from cpl.graphql.service.schema import Schema class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [Schema, RootQuery] + singleton = [Schema, RootQuery, RootMutation] scoped = [GraphQLService] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py index 9b75d694..0c8f5781 100644 --- a/src/cpl-graphql/cpl/graphql/query_context.py +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -1,11 +1,10 @@ from enum import Enum -from typing import Optional, Any +from typing import Optional from graphql import GraphQLResolveInfo from cpl.auth.schema import AuthUser, Permission from cpl.core.ctx import get_user -from cpl.core.utils import get_value class QueryContext: diff --git a/src/cpl-graphql/cpl/graphql/schema/argument.py b/src/cpl-graphql/cpl/graphql/schema/argument.py index cbf8b32f..3332ddd0 100644 --- a/src/cpl-graphql/cpl/graphql/schema/argument.py +++ b/src/cpl-graphql/cpl/graphql/schema/argument.py @@ -1,38 +1,54 @@ -from typing import Any +from typing import Any, Self class Argument: def __init__( self, - t: type, name: str, + t: type, description: str = None, - default_value: Any = None, + default: Any = None, optional: bool = None, ): - self._type = t self._name = name + self._type = t self._description = description - self._default_value = default_value + self._default = default self._optional = optional - @property - def type(self) -> type: - return self._type - @property def name(self) -> str: return self._name + @property + def type(self) -> type: + return self._type + @property def description(self) -> str | None: return self._description @property - def default_value(self) -> Any | None: - return self._default_value + def default(self) -> Any | None: + return self._default @property def optional(self) -> bool | None: return self._optional + + def with_description(self, description: str) -> Self: + self._description = description + return self + + def with_default(self, default: Any) -> Self: + self._default = default + return self + + def with_optional(self, optional: bool) -> Self: + self._optional = optional + return self + + def with_required(self, required: bool = True) -> Self: + self._optional = not required + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 421413a4..8eceba25 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -91,22 +91,26 @@ class Field: self._optional = optional return self + def with_required(self, required: bool = True) -> Self: + self._optional = not required + return self + def with_default(self, default) -> Self: self._default = default return self - def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: + def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") - self._args[name] = Argument(arg_type, name, description, default_value, optional) - return self + self._args[name] = Argument(name, arg_type, description, default_value, optional) + return self._args[name] def with_arguments(self, args: list[Argument]) -> Self: for arg in args: if not isinstance(arg, Argument): raise ValueError(f"Expected Argument instance, got {type(arg)}") - self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) + self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional) return self def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self: @@ -126,7 +130,7 @@ class Field: self._require_any = (permissions, resolvers) return self - def with_public(self, public: bool = False) -> Self: + def with_public(self, public: bool = True) -> Self: assert self._require_any is None, "Field cannot be public and have require_any set" assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set" self._public = public diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index a4dfebdf..6e639db3 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,4 +1,4 @@ -from typing import Generic, Dict, Type, Optional, Self, Union +from typing import Generic, Dict, Type, Optional, Union, Any import strawberry @@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} class Input(StrawberryProtocol, Generic[T]): def __init__(self): self._fields: Dict[str, Field] = {} + self._values: Dict[str, Any] = {} + + @property + def fields(self) -> Dict[str, Field]: + return self._fields + + def __getattr__(self, item): + if item in self._values: + return self._values[item] + raise AttributeError(f"{self.__class__.__name__} has no attribute {item}") + + def __setattr__(self, key, value): + if key in {"_fields", "_values"}: + super().__setattr__(key, value) + elif key in self._fields: + self._values[key] = value + else: + super().__setattr__(key, value) + + def get(self, key: str, default=None): + return self._values.get(key, default) def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): + def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field: self._fields[name] = Field(name, typ, optional=optional) + return self._fields[name] + + def string_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, str) + + def int_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, int, optional) + + def float_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, float, optional) + + def bool_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, bool, optional) + + def list_field(self, name: str, t: type, optional: bool = True) -> Field: + return self.field(name, list[t], optional) + + def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field: + return self.field(name, t().to_strawberry(), optional) def to_strawberry(self) -> Type: cls = self.__class__ diff --git a/src/cpl-graphql/cpl/graphql/schema/mutation.py b/src/cpl-graphql/cpl/graphql/schema/mutation.py new file mode 100644 index 00000000..691cee10 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/mutation.py @@ -0,0 +1,25 @@ +from typing import Type + +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.query_abc import QueryABC +from cpl.graphql.schema.field import Field + + +class Mutation(QueryABC): + + @inject + def __init__(self, provider: ServiceProvider): + QueryABC.__init__(self) + self._provider = provider + + from cpl.graphql.service.schema import Schema + + self._schema = provider.get_service(Schema) + + def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field: + sub = self._provider.get_service(cls) + if not sub: + raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider") + + return self.field(name, sub.to_strawberry(), lambda: sub) diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 0b8df16f..cbd05781 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,76 +1,32 @@ -import functools -import inspect -from asyncio import iscoroutinefunction -from typing import Callable, Type, Any, Optional +from typing import Callable, Type -import strawberry -from strawberry.exceptions import StrawberryException - -from cpl.api import Unauthorized, Forbidden -from cpl.core.ctx import get_user from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.query_abc import QueryABC from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol -from cpl.graphql.error import graphql_error -from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder -from cpl.graphql.typing import Resolver -from cpl.graphql.utils.type_collector import TypeCollector -class Query(StrawberryProtocol): +class Query(QueryABC): @inject def __init__(self, provider: ServiceProvider): + QueryABC.__init__(self) self._provider = provider from cpl.graphql.service.schema import Schema self._schema = provider.get_service(Schema) - self._fields: dict[str, Field] = {} - def get_fields(self) -> dict[str, Field]: - return self._fields - - def field( - self, - name: str, - t: type, - resolver: Resolver = None, - ) -> Field: - from cpl.graphql.schema.field import Field - - self._fields[name] = Field(name, t, resolver) - return self._fields[name] - - def string_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, str, resolver) - - def int_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, int, resolver) - - def float_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, float, resolver) - - def bool_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, bool, resolver) - - def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: - return self.field(name, list[t], resolver) - - def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: - return self.field(name, t().to_strawberry(), resolver) - - def with_query(self, name: str, subquery_cls: Type["Query"]): + def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field: sub = self._provider.get_service(subquery_cls) if not sub: raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider") - self.field(name, sub.to_strawberry(), lambda: sub) - return self + return self.field(name, sub.to_strawberry(), lambda: sub) def collection_field( self, @@ -105,10 +61,10 @@ class Query(StrawberryProtocol): raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection) - f.with_argument(filter.to_strawberry(), "filter") - f.with_argument(sort.to_strawberry(), "sort") - f.with_argument(int, "skip", default_value=0) - f.with_argument(int, "take", default_value=10) + f.with_argument("filter", filter.to_strawberry()) + f.with_argument("sort", sort.to_strawberry()) + f.with_argument("skip", int, default_value=0) + f.with_argument("take", int, default_value=10) return f def dao_collection_field( @@ -168,120 +124,8 @@ class Query(StrawberryProtocol): return Collection(nodes=data, total_count=total_count, count=len(data)) f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver) - f.with_argument(filter.to_strawberry(), "filter") - f.with_argument(sort.to_strawberry(), "sort") - f.with_argument(int, "skip", default_value=0) - f.with_argument(int, "take", default_value=10) + f.with_argument("filter", filter.to_strawberry()) + f.with_argument("sort", sort.to_strawberry()) + f.with_argument("skip", int, default_value=0) + f.with_argument("take", int, default_value=10) return f - - @staticmethod - def _build_resolver(f: "Field"): - params: list[inspect.Parameter] = [] - for arg in f.arguments.values(): - ann = Optional[arg.type] if arg.optional else arg.type - - if arg.default_value is None: - param = inspect.Parameter( - arg.name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=ann, - ) - else: - param = inspect.Parameter( - arg.name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=ann, - default=arg.default_value, - ) - - params.append(param) - - sig = inspect.Signature(parameters=params, return_annotation=f.type) - - async def _resolver(*args, **kwargs): - if f.resolver is None: - return None - - if iscoroutinefunction(f.resolver): - return await f.resolver(*args, **kwargs) - return f.resolver(*args, **kwargs) - - _resolver.__signature__ = sig - return _resolver - - def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: - sig = getattr(resolver, "__signature__", None) - - @functools.wraps(resolver) - async def _auth_resolver(*args, **kwargs): - if f.public: - return await self._run_resolver(resolver, *args, **kwargs) - - user = get_user() - - if user is None: - raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) - - if f.require_any_permission: - if not any([await user.has_permission(p) for p in f.require_any_permission]): - raise graphql_error(Forbidden(f"{f.name}: Permission denied")) - - if f.require_any: - perms, resolvers = f.require_any - if not any([await user.has_permission(p) for p in perms]): - ctx = QueryContext([x.name for x in await user.permissions]) - resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] - - if not any(resolved): - raise graphql_error(Forbidden(f"{f.name}: Permission denied")) - - return await self._run_resolver(resolver, *args, **kwargs) - - if sig: - _auth_resolver.__signature__ = sig - - return _auth_resolver - - @staticmethod - async def _run_resolver(r: Callable, *args, **kwargs): - if iscoroutinefunction(r): - return await r(*args, **kwargs) - return r(*args, **kwargs) - - def _field_to_strawberry(self, f: Field) -> Any: - resolver = None - try: - if f.arguments: - resolver = self._build_resolver(f) - elif not f.resolver: - resolver = lambda *_, **__: None - else: - ann = getattr(f.resolver, "__annotations__", {}) - if "return" not in ann or ann["return"] is None: - ann = dict(ann) - ann["return"] = f.type - f.resolver.__annotations__ = ann - resolver = f.resolver - - return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) - except StrawberryException as e: - raise Exception( - f"Error converting field '{f.name}' to strawberry field: {e}" - ) from e - - def to_strawberry(self) -> Type: - cls = self.__class__ - if TypeCollector.has(cls): - return TypeCollector.get(cls) - - annotations: dict[str, Any] = {} - namespace: dict[str, Any] = {} - - for name, f in self._fields.items(): - annotations[name] = f.type - namespace[name] = self._field_to_strawberry(f) - - namespace["__annotations__"] = annotations - gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) - TypeCollector.set(cls, gql_type) - return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/root_mutation.py b/src/cpl-graphql/cpl/graphql/schema/root_mutation.py new file mode 100644 index 00000000..8855d8e7 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_mutation.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.mutation import Mutation + + +class RootMutation(Mutation): + def __init__(self): + Mutation.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index c1c43cdc..9141f455 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -6,6 +6,7 @@ import strawberry from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery @@ -25,7 +26,17 @@ class Schema: @property def query(self) -> RootQuery: - return self._provider.get_service(RootQuery) + query = self._provider.get_service(RootQuery) + if not query: + raise ValueError("RootQuery not registered in service provider") + return query + + @property + def mutation(self) -> RootMutation: + mutation = self._provider.get_service(RootMutation) + if not mutation: + raise ValueError("RootMutation not registered in service provider") + return mutation def with_type(self, t: Type[StrawberryProtocol]) -> Self: self._types[t.__name__] = t @@ -43,13 +54,13 @@ class Schema: def build(self) -> strawberry.Schema: logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) - query = self._provider.get_service(RootQuery) - if not query: - raise ValueError("RootQuery not registered in service provider") + + query = self.query + mutation = self.mutation self._schema = strawberry.Schema( - query=query.to_strawberry(), - mutation=None, + query=query.to_strawberry() if query.fields_count > 0 else None, + mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None, subscription=None, types=self._get_types(), )