Added mutations #181
This commit is contained in:
@@ -5,7 +5,6 @@ from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, A
|
||||
from api.src.queries.user import UserFilter, UserSort
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.application.application_builder import ApplicationBuilder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema import AuthUser, Role
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
@@ -17,7 +16,7 @@ from cpl.graphql.graphql_module import GraphQLModule
|
||||
from model.author_dao import AuthorDao
|
||||
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
|
||||
from model.post_dao import PostDao
|
||||
from model.post_query import PostFilter, PostSort, PostGraphType
|
||||
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation
|
||||
from permissions import PostPermissions
|
||||
from queries.hello import HelloQuery
|
||||
from scoped_service import ScopedService
|
||||
@@ -64,6 +63,7 @@ def main():
|
||||
.add_transient(PostGraphType)
|
||||
.add_transient(PostFilter)
|
||||
.add_transient(PostSort)
|
||||
.add_transient(PostMutation)
|
||||
)
|
||||
|
||||
app = builder.build()
|
||||
@@ -77,8 +77,8 @@ def main():
|
||||
path="/route1",
|
||||
fn=lambda r: JSONResponse("route1"),
|
||||
method="GET",
|
||||
authentication=True,
|
||||
permissions=[Permissions.administrator],
|
||||
# authentication=True,
|
||||
# permissions=[Permissions.administrator],
|
||||
)
|
||||
app.with_routes_directory("routes")
|
||||
|
||||
@@ -88,9 +88,12 @@ def main():
|
||||
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
||||
(
|
||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
||||
.with_require_any_permission(PostPermissions.read)
|
||||
# .with_require_any_permission(PostPermissions.read)
|
||||
.with_public()
|
||||
)
|
||||
|
||||
schema.mutation.with_mutation("post", PostMutation).with_public()
|
||||
|
||||
app.with_playground()
|
||||
app.with_graphiql()
|
||||
|
||||
|
||||
@@ -7,5 +7,5 @@ class AuthorDao(DbModelDaoABC):
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, Author, "authors")
|
||||
|
||||
self.attribute(Author.first_name, str)
|
||||
self.attribute(Author.last_name, str)
|
||||
self.attribute(Author.first_name, str, db_name="firstname")
|
||||
self.attribute(Author.last_name, str, db_name="lastname")
|
||||
@@ -31,6 +31,14 @@ class Post(DbModelABC[Self]):
|
||||
def title(self) -> str:
|
||||
return self._title
|
||||
|
||||
@title.setter
|
||||
def title(self, value: str):
|
||||
self._title = value
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return self._content
|
||||
|
||||
@content.setter
|
||||
def content(self, value: str):
|
||||
self._content = value
|
||||
|
||||
@@ -3,7 +3,7 @@ from model.author_dao import AuthorDao
|
||||
from model.post import Post
|
||||
|
||||
|
||||
class PostDao(DbModelDaoABC):
|
||||
class PostDao(DbModelDaoABC[Post]):
|
||||
|
||||
def __init__(self, authors: AuthorDao):
|
||||
DbModelDaoABC.__init__(self, Post, "posts")
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.input import Input
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from model.author_dao import AuthorDao
|
||||
from model.author_query import AuthorGraphType, AuthorFilter
|
||||
from model.post import Post
|
||||
from model.post_dao import PostDao
|
||||
|
||||
|
||||
class PostFilter(Filter[Post]):
|
||||
def __init__(self):
|
||||
@@ -38,9 +42,7 @@ class PostGraphType(GraphType[Post]):
|
||||
def r_name(ctx: QueryContext):
|
||||
return ctx.user.username == "admin"
|
||||
|
||||
self.object_field("author", AuthorGraphType, resolver=_a).with_require_any(
|
||||
[], [r_name]
|
||||
)
|
||||
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name]))
|
||||
self.string_field(
|
||||
"title",
|
||||
resolver=lambda root: root.title,
|
||||
@@ -49,3 +51,80 @@ class PostGraphType(GraphType[Post]):
|
||||
"content",
|
||||
resolver=lambda root: root.content,
|
||||
).with_public(True)
|
||||
|
||||
|
||||
class PostCreateInput(Input[Post]):
|
||||
title: str
|
||||
content: str
|
||||
author_id: int
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.string_field("title").with_required()
|
||||
self.string_field("content").with_required()
|
||||
self.int_field("author_id").with_required()
|
||||
|
||||
class PostUpdateInput(Input[Post]):
|
||||
title: str
|
||||
content: str
|
||||
author_id: int
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.int_field("id").with_required()
|
||||
self.string_field("title").with_required(False)
|
||||
self.string_field("content").with_required(False)
|
||||
|
||||
class PostMutation(Mutation):
|
||||
|
||||
def __init__(self, posts: PostDao, authors: AuthorDao):
|
||||
Mutation.__init__(self)
|
||||
|
||||
self._posts = posts
|
||||
self._authors = authors
|
||||
|
||||
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
|
||||
"input",
|
||||
PostCreateInput,
|
||||
).with_required()
|
||||
self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument(
|
||||
"input",
|
||||
PostUpdateInput,
|
||||
).with_required()
|
||||
self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
async def create_post(self, input: PostCreateInput) -> int:
|
||||
return await self._posts.create(Post(0, input.author_id, input.title, input.content))
|
||||
|
||||
async def update_post(self, input: PostUpdateInput) -> bool:
|
||||
post = await self._posts.get_by_id(input.id)
|
||||
if post is None:
|
||||
return False
|
||||
|
||||
post.title = input.title if input.title is not None else post.title
|
||||
post.content = input.content if input.content is not None else post.content
|
||||
|
||||
await self._posts.update(post)
|
||||
return True
|
||||
|
||||
async def delete_post(self, id: int) -> bool:
|
||||
post = await self._posts.get_by_id(id)
|
||||
if post is None:
|
||||
return False
|
||||
await self._posts.delete(post)
|
||||
return True
|
||||
|
||||
async def restore_post(self, id: int) -> bool:
|
||||
post = await self._posts.get_by_id(id)
|
||||
if post is None:
|
||||
return False
|
||||
await self._posts.restore(post)
|
||||
return True
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class HelloQuery(Query):
|
||||
self.string_field(
|
||||
"message",
|
||||
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
|
||||
).with_argument(str, "name", "Name to greet", "world")
|
||||
).with_argument("name", str, "Name to greet", "world")
|
||||
|
||||
self.collection_field(
|
||||
UserGraphType,
|
||||
|
||||
178
src/cpl-graphql/cpl/graphql/abc/query_abc.py
Normal file
178
src/cpl-graphql/cpl/graphql/abc/query_abc.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import functools
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Callable, Type, Any, Optional
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class QueryABC(StrawberryProtocol, ABC):
|
||||
|
||||
def __init__(self):
|
||||
ABC.__init__(self)
|
||||
self._fields: dict[str, Field] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
@property
|
||||
def fields_count(self) -> int:
|
||||
return len(self._fields)
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(
|
||||
self,
|
||||
name: str,
|
||||
t: type,
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
for arg in f.arguments.values():
|
||||
_type = arg.type
|
||||
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
|
||||
_type = _type().to_strawberry()
|
||||
|
||||
ann = Optional[_type] if arg.optional else _type
|
||||
|
||||
if arg.default is None:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
)
|
||||
else:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
default=arg.default,
|
||||
)
|
||||
|
||||
params.append(param)
|
||||
|
||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||
|
||||
async def _resolver(*args, **kwargs):
|
||||
if f.resolver is None:
|
||||
return None
|
||||
|
||||
if iscoroutinefunction(f.resolver):
|
||||
return await f.resolver(*args, **kwargs)
|
||||
return f.resolver(*args, **kwargs)
|
||||
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||
sig = getattr(resolver, "__signature__", None)
|
||||
|
||||
@functools.wraps(resolver)
|
||||
async def _auth_resolver(*args, **kwargs):
|
||||
if f.public:
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
user = get_user()
|
||||
|
||||
if user is None:
|
||||
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
||||
|
||||
if f.require_any_permission:
|
||||
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
if f.require_any:
|
||||
perms, resolvers = f.require_any
|
||||
if not any([await user.has_permission(p) for p in perms]):
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
||||
|
||||
if not any(resolved):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
if sig:
|
||||
_auth_resolver.__signature__ = sig
|
||||
|
||||
return _auth_resolver
|
||||
|
||||
@staticmethod
|
||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||
if iscoroutinefunction(r):
|
||||
return await r(*args, **kwargs)
|
||||
return r(*args, **kwargs)
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
try:
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda *_, **__: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
resolver = f.resolver
|
||||
|
||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||
except StrawberryException as e:
|
||||
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
annotations[name] = f.type
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
@@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.service.schema import Schema
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLModule(Module):
|
||||
dependencies = [ApiModule]
|
||||
singleton = [Schema, RootQuery]
|
||||
singleton = [Schema, RootQuery, RootMutation]
|
||||
scoped = [GraphQLService]
|
||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
|
||||
from graphql import GraphQLResolveInfo
|
||||
|
||||
from cpl.auth.schema import AuthUser, Permission
|
||||
from cpl.core.ctx import get_user
|
||||
from cpl.core.utils import get_value
|
||||
|
||||
|
||||
class QueryContext:
|
||||
|
||||
@@ -1,38 +1,54 @@
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
|
||||
|
||||
class Argument:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
t: type,
|
||||
name: str,
|
||||
t: type,
|
||||
description: str = None,
|
||||
default_value: Any = None,
|
||||
default: Any = None,
|
||||
optional: bool = None,
|
||||
):
|
||||
self._type = t
|
||||
self._name = name
|
||||
self._type = t
|
||||
self._description = description
|
||||
self._default_value = default_value
|
||||
self._default = default
|
||||
self._optional = optional
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def default_value(self) -> Any | None:
|
||||
return self._default_value
|
||||
def default(self) -> Any | None:
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def optional(self) -> bool | None:
|
||||
return self._optional
|
||||
|
||||
def with_description(self, description: str) -> Self:
|
||||
self._description = description
|
||||
return self
|
||||
|
||||
def with_default(self, default: Any) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_optional(self, optional: bool) -> Self:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_required(self, required: bool = True) -> Self:
|
||||
self._optional = not required
|
||||
return self
|
||||
|
||||
@@ -91,22 +91,26 @@ class Field:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_required(self, required: bool = True) -> Self:
|
||||
self._optional = not required
|
||||
return self
|
||||
|
||||
def with_default(self, default) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
||||
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
|
||||
if name in self._args:
|
||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||
self._args[name] = Argument(arg_type, name, description, default_value, optional)
|
||||
return self
|
||||
self._args[name] = Argument(name, arg_type, description, default_value, optional)
|
||||
return self._args[name]
|
||||
|
||||
def with_arguments(self, args: list[Argument]) -> Self:
|
||||
for arg in args:
|
||||
if not isinstance(arg, Argument):
|
||||
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
||||
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
|
||||
return self
|
||||
|
||||
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
|
||||
@@ -126,7 +130,7 @@ class Field:
|
||||
self._require_any = (permissions, resolvers)
|
||||
return self
|
||||
|
||||
def with_public(self, public: bool = False) -> Self:
|
||||
def with_public(self, public: bool = True) -> Self:
|
||||
assert self._require_any is None, "Field cannot be public and have require_any set"
|
||||
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
|
||||
self._public = public
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Generic, Dict, Type, Optional, Self, Union
|
||||
from typing import Generic, Dict, Type, Optional, Union, Any
|
||||
|
||||
import strawberry
|
||||
|
||||
@@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||
class Input(StrawberryProtocol, Generic[T]):
|
||||
def __init__(self):
|
||||
self._fields: Dict[str, Field] = {}
|
||||
self._values: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> Dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self._values:
|
||||
return self._values[item]
|
||||
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in {"_fields", "_values"}:
|
||||
super().__setattr__(key, value)
|
||||
elif key in self._fields:
|
||||
self._values[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, str)
|
||||
|
||||
def int_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, int, optional)
|
||||
|
||||
def float_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, float, optional)
|
||||
|
||||
def bool_field(self, name: str, optional: bool = True) -> Field:
|
||||
return self.field(name, bool, optional)
|
||||
|
||||
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
|
||||
return self.field(name, list[t], optional)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||
return self.field(name, t().to_strawberry(), optional)
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
|
||||
25
src/cpl-graphql/cpl/graphql/schema/mutation.py
Normal file
25
src/cpl-graphql/cpl/graphql/schema/mutation.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Type
|
||||
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
|
||||
class Mutation(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
QueryABC.__init__(self)
|
||||
self._provider = provider
|
||||
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
|
||||
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
|
||||
sub = self._provider.get_service(cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
|
||||
|
||||
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
@@ -1,76 +1,32 @@
|
||||
import functools
|
||||
import inspect
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Callable, Type, Any, Optional
|
||||
from typing import Callable, Type
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx import get_user
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class Query(StrawberryProtocol):
|
||||
class Query(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
QueryABC.__init__(self)
|
||||
self._provider = provider
|
||||
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
self._fields: dict[str, Field] = {}
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(
|
||||
self,
|
||||
name: str,
|
||||
t: type,
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
|
||||
sub = self._provider.get_service(subquery_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
||||
|
||||
self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
return self
|
||||
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
|
||||
def collection_field(
|
||||
self,
|
||||
@@ -105,10 +61,10 @@ class Query(StrawberryProtocol):
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
||||
f.with_argument(filter.to_strawberry(), "filter")
|
||||
f.with_argument(sort.to_strawberry(), "sort")
|
||||
f.with_argument(int, "skip", default_value=0)
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
f.with_argument("filter", filter.to_strawberry())
|
||||
f.with_argument("sort", sort.to_strawberry())
|
||||
f.with_argument("skip", int, default_value=0)
|
||||
f.with_argument("take", int, default_value=10)
|
||||
return f
|
||||
|
||||
def dao_collection_field(
|
||||
@@ -168,120 +124,8 @@ class Query(StrawberryProtocol):
|
||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||
f.with_argument(filter.to_strawberry(), "filter")
|
||||
f.with_argument(sort.to_strawberry(), "sort")
|
||||
f.with_argument(int, "skip", default_value=0)
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
f.with_argument("filter", filter.to_strawberry())
|
||||
f.with_argument("sort", sort.to_strawberry())
|
||||
f.with_argument("skip", int, default_value=0)
|
||||
f.with_argument("take", int, default_value=10)
|
||||
return f
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
for arg in f.arguments.values():
|
||||
ann = Optional[arg.type] if arg.optional else arg.type
|
||||
|
||||
if arg.default_value is None:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
)
|
||||
else:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
default=arg.default_value,
|
||||
)
|
||||
|
||||
params.append(param)
|
||||
|
||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||
|
||||
async def _resolver(*args, **kwargs):
|
||||
if f.resolver is None:
|
||||
return None
|
||||
|
||||
if iscoroutinefunction(f.resolver):
|
||||
return await f.resolver(*args, **kwargs)
|
||||
return f.resolver(*args, **kwargs)
|
||||
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||
sig = getattr(resolver, "__signature__", None)
|
||||
|
||||
@functools.wraps(resolver)
|
||||
async def _auth_resolver(*args, **kwargs):
|
||||
if f.public:
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
user = get_user()
|
||||
|
||||
if user is None:
|
||||
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
||||
|
||||
if f.require_any_permission:
|
||||
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
if f.require_any:
|
||||
perms, resolvers = f.require_any
|
||||
if not any([await user.has_permission(p) for p in perms]):
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
||||
|
||||
if not any(resolved):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
if sig:
|
||||
_auth_resolver.__signature__ = sig
|
||||
|
||||
return _auth_resolver
|
||||
|
||||
@staticmethod
|
||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||
if iscoroutinefunction(r):
|
||||
return await r(*args, **kwargs)
|
||||
return r(*args, **kwargs)
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
try:
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda *_, **__: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
resolver = f.resolver
|
||||
|
||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||
except StrawberryException as e:
|
||||
raise Exception(
|
||||
f"Error converting field '{f.name}' to strawberry field: {e}"
|
||||
) from e
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
annotations[name] = f.type
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
6
src/cpl-graphql/cpl/graphql/schema/root_mutation.py
Normal file
6
src/cpl-graphql/cpl/graphql/schema/root_mutation.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
|
||||
class RootMutation(Mutation):
|
||||
def __init__(self):
|
||||
Mutation.__init__(self)
|
||||
@@ -6,6 +6,7 @@ import strawberry
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
|
||||
|
||||
@@ -25,7 +26,17 @@ class Schema:
|
||||
|
||||
@property
|
||||
def query(self) -> RootQuery:
|
||||
return self._provider.get_service(RootQuery)
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
return query
|
||||
|
||||
@property
|
||||
def mutation(self) -> RootMutation:
|
||||
mutation = self._provider.get_service(RootMutation)
|
||||
if not mutation:
|
||||
raise ValueError("RootMutation not registered in service provider")
|
||||
return mutation
|
||||
|
||||
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||
self._types[t.__name__] = t
|
||||
@@ -43,13 +54,13 @@ class Schema:
|
||||
|
||||
def build(self) -> strawberry.Schema:
|
||||
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
|
||||
query = self.query
|
||||
mutation = self.mutation
|
||||
|
||||
self._schema = strawberry.Schema(
|
||||
query=query.to_strawberry(),
|
||||
mutation=None,
|
||||
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||
subscription=None,
|
||||
types=self._get_types(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user