Added mutations #181

This commit is contained in:
2025-09-28 18:51:28 +02:00
parent 3286a95cbf
commit 39d06dfe48
16 changed files with 424 additions and 210 deletions

View File

@@ -5,7 +5,6 @@ from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, A
from api.src.queries.user import UserFilter, UserSort from api.src.queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder from cpl.application.application_builder import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.console import Console from cpl.core.console import Console
@@ -17,7 +16,7 @@ from cpl.graphql.graphql_module import GraphQLModule
from model.author_dao import AuthorDao from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
from model.post_dao import PostDao from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation
from permissions import PostPermissions from permissions import PostPermissions
from queries.hello import HelloQuery from queries.hello import HelloQuery
from scoped_service import ScopedService from scoped_service import ScopedService
@@ -64,6 +63,7 @@ def main():
.add_transient(PostGraphType) .add_transient(PostGraphType)
.add_transient(PostFilter) .add_transient(PostFilter)
.add_transient(PostSort) .add_transient(PostSort)
.add_transient(PostMutation)
) )
app = builder.build() app = builder.build()
@@ -77,8 +77,8 @@ def main():
path="/route1", path="/route1",
fn=lambda r: JSONResponse("route1"), fn=lambda r: JSONResponse("route1"),
method="GET", method="GET",
authentication=True, # authentication=True,
permissions=[Permissions.administrator], # permissions=[Permissions.administrator],
) )
app.with_routes_directory("routes") app.with_routes_directory("routes")
@@ -88,9 +88,12 @@ def main():
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
( (
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
.with_require_any_permission(PostPermissions.read) # .with_require_any_permission(PostPermissions.read)
.with_public()
) )
schema.mutation.with_mutation("post", PostMutation).with_public()
app.with_playground() app.with_playground()
app.with_graphiql() app.with_graphiql()

View File

@@ -7,5 +7,5 @@ class AuthorDao(DbModelDaoABC):
def __init__(self): def __init__(self):
DbModelDaoABC.__init__(self, Author, "authors") DbModelDaoABC.__init__(self, Author, "authors")
self.attribute(Author.first_name, str) self.attribute(Author.first_name, str, db_name="firstname")
self.attribute(Author.last_name, str) self.attribute(Author.last_name, str, db_name="lastname")

View File

@@ -31,6 +31,14 @@ class Post(DbModelABC[Self]):
def title(self) -> str: def title(self) -> str:
return self._title return self._title
@title.setter
def title(self, value: str):
self._title = value
@property @property
def content(self) -> str: def content(self) -> str:
return self._content return self._content
@content.setter
def content(self, value: str):
self._content = value

View File

@@ -3,7 +3,7 @@ from model.author_dao import AuthorDao
from model.post import Post from model.post import Post
class PostDao(DbModelDaoABC): class PostDao(DbModelDaoABC[Post]):
def __init__(self, authors: AuthorDao): def __init__(self, authors: AuthorDao):
DbModelDaoABC.__init__(self, Post, "posts") DbModelDaoABC.__init__(self, Post, "posts")

View File

@@ -1,11 +1,15 @@
from cpl.graphql.query_context import QueryContext from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.input import Input
from cpl.graphql.schema.mutation import Mutation
from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
from model.author_dao import AuthorDao from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter from model.author_query import AuthorGraphType, AuthorFilter
from model.post import Post from model.post import Post
from model.post_dao import PostDao
class PostFilter(Filter[Post]): class PostFilter(Filter[Post]):
def __init__(self): def __init__(self):
@@ -38,9 +42,7 @@ class PostGraphType(GraphType[Post]):
def r_name(ctx: QueryContext): def r_name(ctx: QueryContext):
return ctx.user.username == "admin" return ctx.user.username == "admin"
self.object_field("author", AuthorGraphType, resolver=_a).with_require_any( self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name]))
[], [r_name]
)
self.string_field( self.string_field(
"title", "title",
resolver=lambda root: root.title, resolver=lambda root: root.title,
@@ -49,3 +51,80 @@ class PostGraphType(GraphType[Post]):
"content", "content",
resolver=lambda root: root.content, resolver=lambda root: root.content,
).with_public(True) ).with_public(True)
class PostCreateInput(Input[Post]):
title: str
content: str
author_id: int
def __init__(self):
Input.__init__(self)
self.string_field("title").with_required()
self.string_field("content").with_required()
self.int_field("author_id").with_required()
class PostUpdateInput(Input[Post]):
title: str
content: str
author_id: int
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.string_field("title").with_required(False)
self.string_field("content").with_required(False)
class PostMutation(Mutation):
def __init__(self, posts: PostDao, authors: AuthorDao):
Mutation.__init__(self)
self._posts = posts
self._authors = authors
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
"input",
PostCreateInput,
).with_required()
self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument(
"input",
PostUpdateInput,
).with_required()
self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument(
"id",
int,
).with_required()
self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument(
"id",
int,
).with_required()
async def create_post(self, input: PostCreateInput) -> int:
return await self._posts.create(Post(0, input.author_id, input.title, input.content))
async def update_post(self, input: PostUpdateInput) -> bool:
post = await self._posts.get_by_id(input.id)
if post is None:
return False
post.title = input.title if input.title is not None else post.title
post.content = input.content if input.content is not None else post.content
await self._posts.update(post)
return True
async def delete_post(self, id: int) -> bool:
post = await self._posts.get_by_id(id)
if post is None:
return False
await self._posts.delete(post)
return True
async def restore_post(self, id: int) -> bool:
post = await self._posts.get_by_id(id)
if post is None:
return False
await self._posts.restore(post)
return True

View File

@@ -44,7 +44,7 @@ class HelloQuery(Query):
self.string_field( self.string_field(
"message", "message",
resolver=lambda name: f"Hello {name} {get_request().state.request_id}", resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
).with_argument(str, "name", "Name to greet", "world") ).with_argument("name", str, "Name to greet", "world")
self.collection_field( self.collection_field(
UserGraphType, UserGraphType,

View File

@@ -0,0 +1,178 @@
import functools
import inspect
from abc import ABC
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Resolver
from cpl.graphql.utils.type_collector import TypeCollector
class QueryABC(StrawberryProtocol, ABC):
def __init__(self):
ABC.__init__(self)
self._fields: dict[str, Field] = {}
@property
def fields(self) -> dict[str, Field]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
name: str,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
return self.field(name, t().to_strawberry(), resolver)
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
_type = arg.type
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
_type = _type().to_strawberry()
ann = Optional[_type] if arg.optional else _type
if arg.default is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
async def _resolver(*args, **kwargs):
if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user()
if user is None:
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
if f.require_any_permission:
if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
if f.require_any:
perms, resolvers = f.require_any
if not any([await user.has_permission(p) for p in perms]):
ctx = QueryContext([x.name for x in await user.permissions])
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
if iscoroutinefunction(r):
return await r(*args, **kwargs)
return r(*args, **kwargs)
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
annotations[name] = f.type
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema
from cpl.graphql.service.graphql import GraphQLService from cpl.graphql.service.graphql import GraphQLService
from cpl.graphql.service.schema import Schema
class GraphQLModule(Module): class GraphQLModule(Module):
dependencies = [ApiModule] dependencies = [ApiModule]
singleton = [Schema, RootQuery] singleton = [Schema, RootQuery, RootMutation]
scoped = [GraphQLService] scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]

View File

@@ -1,11 +1,10 @@
from enum import Enum from enum import Enum
from typing import Optional, Any from typing import Optional
from graphql import GraphQLResolveInfo from graphql import GraphQLResolveInfo
from cpl.auth.schema import AuthUser, Permission from cpl.auth.schema import AuthUser, Permission
from cpl.core.ctx import get_user from cpl.core.ctx import get_user
from cpl.core.utils import get_value
class QueryContext: class QueryContext:

View File

@@ -1,38 +1,54 @@
from typing import Any from typing import Any, Self
class Argument: class Argument:
def __init__( def __init__(
self, self,
t: type,
name: str, name: str,
t: type,
description: str = None, description: str = None,
default_value: Any = None, default: Any = None,
optional: bool = None, optional: bool = None,
): ):
self._type = t
self._name = name self._name = name
self._type = t
self._description = description self._description = description
self._default_value = default_value self._default = default
self._optional = optional self._optional = optional
@property
def type(self) -> type:
return self._type
@property @property
def name(self) -> str: def name(self) -> str:
return self._name return self._name
@property
def type(self) -> type:
return self._type
@property @property
def description(self) -> str | None: def description(self) -> str | None:
return self._description return self._description
@property @property
def default_value(self) -> Any | None: def default(self) -> Any | None:
return self._default_value return self._default
@property @property
def optional(self) -> bool | None: def optional(self) -> bool | None:
return self._optional return self._optional
def with_description(self, description: str) -> Self:
self._description = description
return self
def with_default(self, default: Any) -> Self:
self._default = default
return self
def with_optional(self, optional: bool) -> Self:
self._optional = optional
return self
def with_required(self, required: bool = True) -> Self:
self._optional = not required
return self

View File

@@ -91,22 +91,26 @@ class Field:
self._optional = optional self._optional = optional
return self return self
def with_required(self, required: bool = True) -> Self:
self._optional = not required
return self
def with_default(self, default) -> Self: def with_default(self, default) -> Self:
self._default = default self._default = default
return self return self
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
if name in self._args: if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(arg_type, name, description, default_value, optional) self._args[name] = Argument(name, arg_type, description, default_value, optional)
return self return self._args[name]
def with_arguments(self, args: list[Argument]) -> Self: def with_arguments(self, args: list[Argument]) -> Self:
for arg in args: for arg in args:
if not isinstance(arg, Argument): if not isinstance(arg, Argument):
raise ValueError(f"Expected Argument instance, got {type(arg)}") raise ValueError(f"Expected Argument instance, got {type(arg)}")
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
return self return self
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self: def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
@@ -126,7 +130,7 @@ class Field:
self._require_any = (permissions, resolvers) self._require_any = (permissions, resolvers)
return self return self
def with_public(self, public: bool = False) -> Self: def with_public(self, public: bool = True) -> Self:
assert self._require_any is None, "Field cannot be public and have require_any set" assert self._require_any is None, "Field cannot be public and have require_any set"
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set" assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
self._public = public self._public = public

View File

@@ -1,4 +1,4 @@
from typing import Generic, Dict, Type, Optional, Self, Union from typing import Generic, Dict, Type, Optional, Union, Any
import strawberry import strawberry
@@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
class Input(StrawberryProtocol, Generic[T]): class Input(StrawberryProtocol, Generic[T]):
def __init__(self): def __init__(self):
self._fields: Dict[str, Field] = {} self._fields: Dict[str, Field] = {}
self._values: Dict[str, Any] = {}
@property
def fields(self) -> Dict[str, Field]:
return self._fields
def __getattr__(self, item):
if item in self._values:
return self._values[item]
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
def __setattr__(self, key, value):
if key in {"_fields", "_values"}:
super().__setattr__(key, value)
elif key in self._fields:
self._values[key] = value
else:
super().__setattr__(key, value)
def get(self, key: str, default=None):
return self._values.get(key, default)
def get_fields(self) -> dict[str, Field]: def get_fields(self) -> dict[str, Field]:
return self._fields return self._fields
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
self._fields[name] = Field(name, typ, optional=optional) self._fields[name] = Field(name, typ, optional=optional)
return self._fields[name]
def string_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, str)
def int_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, int, optional)
def float_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, float, optional)
def bool_field(self, name: str, optional: bool = True) -> Field:
return self.field(name, bool, optional)
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
return self.field(name, list[t], optional)
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
return self.field(name, t().to_strawberry(), optional)
def to_strawberry(self) -> Type: def to_strawberry(self) -> Type:
cls = self.__class__ cls = self.__class__

View File

@@ -0,0 +1,25 @@
from typing import Type
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.schema.field import Field
class Mutation(QueryABC):
@inject
def __init__(self, provider: ServiceProvider):
QueryABC.__init__(self)
self._provider = provider
from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema)
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
sub = self._provider.get_service(cls)
if not sub:
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
return self.field(name, sub.to_strawberry(), lambda: sub)

View File

@@ -1,76 +1,32 @@
import functools from typing import Callable, Type
import inspect
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx import get_user
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.typing import Resolver
from cpl.graphql.utils.type_collector import TypeCollector
class Query(StrawberryProtocol): class Query(QueryABC):
@inject @inject
def __init__(self, provider: ServiceProvider): def __init__(self, provider: ServiceProvider):
QueryABC.__init__(self)
self._provider = provider self._provider = provider
from cpl.graphql.service.schema import Schema from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema) self._schema = provider.get_service(Schema)
self._fields: dict[str, Field] = {}
def get_fields(self) -> dict[str, Field]: def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
return self._fields
def field(
self,
name: str,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
return self.field(name, t().to_strawberry(), resolver)
def with_query(self, name: str, subquery_cls: Type["Query"]):
sub = self._provider.get_service(subquery_cls) sub = self._provider.get_service(subquery_cls)
if not sub: if not sub:
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider") raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
self.field(name, sub.to_strawberry(), lambda: sub) return self.field(name, sub.to_strawberry(), lambda: sub)
return self
def collection_field( def collection_field(
self, self,
@@ -105,10 +61,10 @@ class Query(StrawberryProtocol):
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection) f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
f.with_argument(filter.to_strawberry(), "filter") f.with_argument("filter", filter.to_strawberry())
f.with_argument(sort.to_strawberry(), "sort") f.with_argument("sort", sort.to_strawberry())
f.with_argument(int, "skip", default_value=0) f.with_argument("skip", int, default_value=0)
f.with_argument(int, "take", default_value=10) f.with_argument("take", int, default_value=10)
return f return f
def dao_collection_field( def dao_collection_field(
@@ -168,120 +124,8 @@ class Query(StrawberryProtocol):
return Collection(nodes=data, total_count=total_count, count=len(data)) return Collection(nodes=data, total_count=total_count, count=len(data))
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver) f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
f.with_argument(filter.to_strawberry(), "filter") f.with_argument("filter", filter.to_strawberry())
f.with_argument(sort.to_strawberry(), "sort") f.with_argument("sort", sort.to_strawberry())
f.with_argument(int, "skip", default_value=0) f.with_argument("skip", int, default_value=0)
f.with_argument(int, "take", default_value=10) f.with_argument("take", int, default_value=10)
return f return f
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
ann = Optional[arg.type] if arg.optional else arg.type
if arg.default_value is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default_value,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
async def _resolver(*args, **kwargs):
if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user()
if user is None:
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
if f.require_any_permission:
if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
if f.require_any:
perms, resolvers = f.require_any
if not any([await user.has_permission(p) for p in perms]):
ctx = QueryContext([x.name for x in await user.permissions])
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
if iscoroutinefunction(r):
return await r(*args, **kwargs)
return r(*args, **kwargs)
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(
f"Error converting field '{f.name}' to strawberry field: {e}"
) from e
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
annotations[name] = f.type
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.mutation import Mutation
class RootMutation(Mutation):
def __init__(self):
Mutation.__init__(self)

View File

@@ -6,6 +6,7 @@ import strawberry
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.schema.root_query import RootQuery
@@ -25,7 +26,17 @@ class Schema:
@property @property
def query(self) -> RootQuery: def query(self) -> RootQuery:
return self._provider.get_service(RootQuery) query = self._provider.get_service(RootQuery)
if not query:
raise ValueError("RootQuery not registered in service provider")
return query
@property
def mutation(self) -> RootMutation:
mutation = self._provider.get_service(RootMutation)
if not mutation:
raise ValueError("RootMutation not registered in service provider")
return mutation
def with_type(self, t: Type[StrawberryProtocol]) -> Self: def with_type(self, t: Type[StrawberryProtocol]) -> Self:
self._types[t.__name__] = t self._types[t.__name__] = t
@@ -43,13 +54,13 @@ class Schema:
def build(self) -> strawberry.Schema: def build(self) -> strawberry.Schema:
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
query = self._provider.get_service(RootQuery)
if not query: query = self.query
raise ValueError("RootQuery not registered in service provider") mutation = self.mutation
self._schema = strawberry.Schema( self._schema = strawberry.Schema(
query=query.to_strawberry(), query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=None, mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=None, subscription=None,
types=self._get_types(), types=self._get_types(),
) )