Added mutations #181
This commit is contained in:
@@ -5,7 +5,6 @@ from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, A
|
|||||||
from api.src.queries.user import UserFilter, UserSort
|
from api.src.queries.user import UserFilter, UserSort
|
||||||
from cpl.api.api_module import ApiModule
|
from cpl.api.api_module import ApiModule
|
||||||
from cpl.application.application_builder import ApplicationBuilder
|
from cpl.application.application_builder import ApplicationBuilder
|
||||||
from cpl.auth.permission.permissions import Permissions
|
|
||||||
from cpl.auth.schema import AuthUser, Role
|
from cpl.auth.schema import AuthUser, Role
|
||||||
from cpl.core.configuration import Configuration
|
from cpl.core.configuration import Configuration
|
||||||
from cpl.core.console import Console
|
from cpl.core.console import Console
|
||||||
@@ -17,7 +16,7 @@ from cpl.graphql.graphql_module import GraphQLModule
|
|||||||
from model.author_dao import AuthorDao
|
from model.author_dao import AuthorDao
|
||||||
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
|
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
|
||||||
from model.post_dao import PostDao
|
from model.post_dao import PostDao
|
||||||
from model.post_query import PostFilter, PostSort, PostGraphType
|
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation
|
||||||
from permissions import PostPermissions
|
from permissions import PostPermissions
|
||||||
from queries.hello import HelloQuery
|
from queries.hello import HelloQuery
|
||||||
from scoped_service import ScopedService
|
from scoped_service import ScopedService
|
||||||
@@ -64,6 +63,7 @@ def main():
|
|||||||
.add_transient(PostGraphType)
|
.add_transient(PostGraphType)
|
||||||
.add_transient(PostFilter)
|
.add_transient(PostFilter)
|
||||||
.add_transient(PostSort)
|
.add_transient(PostSort)
|
||||||
|
.add_transient(PostMutation)
|
||||||
)
|
)
|
||||||
|
|
||||||
app = builder.build()
|
app = builder.build()
|
||||||
@@ -77,8 +77,8 @@ def main():
|
|||||||
path="/route1",
|
path="/route1",
|
||||||
fn=lambda r: JSONResponse("route1"),
|
fn=lambda r: JSONResponse("route1"),
|
||||||
method="GET",
|
method="GET",
|
||||||
authentication=True,
|
# authentication=True,
|
||||||
permissions=[Permissions.administrator],
|
# permissions=[Permissions.administrator],
|
||||||
)
|
)
|
||||||
app.with_routes_directory("routes")
|
app.with_routes_directory("routes")
|
||||||
|
|
||||||
@@ -88,9 +88,12 @@ def main():
|
|||||||
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
||||||
(
|
(
|
||||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
||||||
.with_require_any_permission(PostPermissions.read)
|
# .with_require_any_permission(PostPermissions.read)
|
||||||
|
.with_public()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
schema.mutation.with_mutation("post", PostMutation).with_public()
|
||||||
|
|
||||||
app.with_playground()
|
app.with_playground()
|
||||||
app.with_graphiql()
|
app.with_graphiql()
|
||||||
|
|
||||||
|
|||||||
@@ -7,5 +7,5 @@ class AuthorDao(DbModelDaoABC):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
DbModelDaoABC.__init__(self, Author, "authors")
|
DbModelDaoABC.__init__(self, Author, "authors")
|
||||||
|
|
||||||
self.attribute(Author.first_name, str)
|
self.attribute(Author.first_name, str, db_name="firstname")
|
||||||
self.attribute(Author.last_name, str)
|
self.attribute(Author.last_name, str, db_name="lastname")
|
||||||
@@ -31,6 +31,14 @@ class Post(DbModelABC[Self]):
|
|||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return self._title
|
return self._title
|
||||||
|
|
||||||
|
@title.setter
|
||||||
|
def title(self, value: str):
|
||||||
|
self._title = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return self._content
|
return self._content
|
||||||
|
|
||||||
|
@content.setter
|
||||||
|
def content(self, value: str):
|
||||||
|
self._content = value
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from model.author_dao import AuthorDao
|
|||||||
from model.post import Post
|
from model.post import Post
|
||||||
|
|
||||||
|
|
||||||
class PostDao(DbModelDaoABC):
|
class PostDao(DbModelDaoABC[Post]):
|
||||||
|
|
||||||
def __init__(self, authors: AuthorDao):
|
def __init__(self, authors: AuthorDao):
|
||||||
DbModelDaoABC.__init__(self, Post, "posts")
|
DbModelDaoABC.__init__(self, Post, "posts")
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
from cpl.graphql.query_context import QueryContext
|
from cpl.graphql.query_context import QueryContext
|
||||||
from cpl.graphql.schema.filter.filter import Filter
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
from cpl.graphql.schema.graph_type import GraphType
|
from cpl.graphql.schema.graph_type import GraphType
|
||||||
|
from cpl.graphql.schema.input import Input
|
||||||
|
from cpl.graphql.schema.mutation import Mutation
|
||||||
from cpl.graphql.schema.sort.sort import Sort
|
from cpl.graphql.schema.sort.sort import Sort
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
from model.author_dao import AuthorDao
|
from model.author_dao import AuthorDao
|
||||||
from model.author_query import AuthorGraphType, AuthorFilter
|
from model.author_query import AuthorGraphType, AuthorFilter
|
||||||
from model.post import Post
|
from model.post import Post
|
||||||
|
from model.post_dao import PostDao
|
||||||
|
|
||||||
|
|
||||||
class PostFilter(Filter[Post]):
|
class PostFilter(Filter[Post]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -38,9 +42,7 @@ class PostGraphType(GraphType[Post]):
|
|||||||
def r_name(ctx: QueryContext):
|
def r_name(ctx: QueryContext):
|
||||||
return ctx.user.username == "admin"
|
return ctx.user.username == "admin"
|
||||||
|
|
||||||
self.object_field("author", AuthorGraphType, resolver=_a).with_require_any(
|
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name]))
|
||||||
[], [r_name]
|
|
||||||
)
|
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"title",
|
"title",
|
||||||
resolver=lambda root: root.title,
|
resolver=lambda root: root.title,
|
||||||
@@ -49,3 +51,80 @@ class PostGraphType(GraphType[Post]):
|
|||||||
"content",
|
"content",
|
||||||
resolver=lambda root: root.content,
|
resolver=lambda root: root.content,
|
||||||
).with_public(True)
|
).with_public(True)
|
||||||
|
|
||||||
|
|
||||||
|
class PostCreateInput(Input[Post]):
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
author_id: int
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
Input.__init__(self)
|
||||||
|
self.string_field("title").with_required()
|
||||||
|
self.string_field("content").with_required()
|
||||||
|
self.int_field("author_id").with_required()
|
||||||
|
|
||||||
|
class PostUpdateInput(Input[Post]):
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
author_id: int
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
Input.__init__(self)
|
||||||
|
self.int_field("id").with_required()
|
||||||
|
self.string_field("title").with_required(False)
|
||||||
|
self.string_field("content").with_required(False)
|
||||||
|
|
||||||
|
class PostMutation(Mutation):
|
||||||
|
|
||||||
|
def __init__(self, posts: PostDao, authors: AuthorDao):
|
||||||
|
Mutation.__init__(self)
|
||||||
|
|
||||||
|
self._posts = posts
|
||||||
|
self._authors = authors
|
||||||
|
|
||||||
|
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
|
||||||
|
"input",
|
||||||
|
PostCreateInput,
|
||||||
|
).with_required()
|
||||||
|
self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument(
|
||||||
|
"input",
|
||||||
|
PostUpdateInput,
|
||||||
|
).with_required()
|
||||||
|
self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument(
|
||||||
|
"id",
|
||||||
|
int,
|
||||||
|
).with_required()
|
||||||
|
self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument(
|
||||||
|
"id",
|
||||||
|
int,
|
||||||
|
).with_required()
|
||||||
|
|
||||||
|
async def create_post(self, input: PostCreateInput) -> int:
|
||||||
|
return await self._posts.create(Post(0, input.author_id, input.title, input.content))
|
||||||
|
|
||||||
|
async def update_post(self, input: PostUpdateInput) -> bool:
|
||||||
|
post = await self._posts.get_by_id(input.id)
|
||||||
|
if post is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
post.title = input.title if input.title is not None else post.title
|
||||||
|
post.content = input.content if input.content is not None else post.content
|
||||||
|
|
||||||
|
await self._posts.update(post)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def delete_post(self, id: int) -> bool:
|
||||||
|
post = await self._posts.get_by_id(id)
|
||||||
|
if post is None:
|
||||||
|
return False
|
||||||
|
await self._posts.delete(post)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def restore_post(self, id: int) -> bool:
|
||||||
|
post = await self._posts.get_by_id(id)
|
||||||
|
if post is None:
|
||||||
|
return False
|
||||||
|
await self._posts.restore(post)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class HelloQuery(Query):
|
|||||||
self.string_field(
|
self.string_field(
|
||||||
"message",
|
"message",
|
||||||
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
|
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
|
||||||
).with_argument(str, "name", "Name to greet", "world")
|
).with_argument("name", str, "Name to greet", "world")
|
||||||
|
|
||||||
self.collection_field(
|
self.collection_field(
|
||||||
UserGraphType,
|
UserGraphType,
|
||||||
|
|||||||
178
src/cpl-graphql/cpl/graphql/abc/query_abc.py
Normal file
178
src/cpl-graphql/cpl/graphql/abc/query_abc.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
from abc import ABC
|
||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
from typing import Callable, Type, Any, Optional
|
||||||
|
|
||||||
|
import strawberry
|
||||||
|
from strawberry.exceptions import StrawberryException
|
||||||
|
|
||||||
|
from cpl.api import Unauthorized, Forbidden
|
||||||
|
from cpl.core.ctx.user_context import get_user
|
||||||
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
|
from cpl.graphql.error import graphql_error
|
||||||
|
from cpl.graphql.query_context import QueryContext
|
||||||
|
from cpl.graphql.schema.field import Field
|
||||||
|
from cpl.graphql.typing import Resolver
|
||||||
|
from cpl.graphql.utils.type_collector import TypeCollector
|
||||||
|
|
||||||
|
|
||||||
|
class QueryABC(StrawberryProtocol, ABC):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
ABC.__init__(self)
|
||||||
|
self._fields: dict[str, Field] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fields(self) -> dict[str, Field]:
|
||||||
|
return self._fields
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fields_count(self) -> int:
|
||||||
|
return len(self._fields)
|
||||||
|
|
||||||
|
def get_fields(self) -> dict[str, Field]:
|
||||||
|
return self._fields
|
||||||
|
|
||||||
|
def field(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
t: type,
|
||||||
|
resolver: Resolver = None,
|
||||||
|
) -> Field:
|
||||||
|
from cpl.graphql.schema.field import Field
|
||||||
|
|
||||||
|
self._fields[name] = Field(name, t, resolver)
|
||||||
|
return self._fields[name]
|
||||||
|
|
||||||
|
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, str, resolver)
|
||||||
|
|
||||||
|
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, int, resolver)
|
||||||
|
|
||||||
|
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, float, resolver)
|
||||||
|
|
||||||
|
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, bool, resolver)
|
||||||
|
|
||||||
|
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, list[t], resolver)
|
||||||
|
|
||||||
|
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, t().to_strawberry(), resolver)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_resolver(f: "Field"):
|
||||||
|
params: list[inspect.Parameter] = []
|
||||||
|
for arg in f.arguments.values():
|
||||||
|
_type = arg.type
|
||||||
|
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
|
||||||
|
_type = _type().to_strawberry()
|
||||||
|
|
||||||
|
ann = Optional[_type] if arg.optional else _type
|
||||||
|
|
||||||
|
if arg.default is None:
|
||||||
|
param = inspect.Parameter(
|
||||||
|
arg.name,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=ann,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
param = inspect.Parameter(
|
||||||
|
arg.name,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=ann,
|
||||||
|
default=arg.default,
|
||||||
|
)
|
||||||
|
|
||||||
|
params.append(param)
|
||||||
|
|
||||||
|
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||||
|
|
||||||
|
async def _resolver(*args, **kwargs):
|
||||||
|
if f.resolver is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if iscoroutinefunction(f.resolver):
|
||||||
|
return await f.resolver(*args, **kwargs)
|
||||||
|
return f.resolver(*args, **kwargs)
|
||||||
|
|
||||||
|
_resolver.__signature__ = sig
|
||||||
|
return _resolver
|
||||||
|
|
||||||
|
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||||
|
sig = getattr(resolver, "__signature__", None)
|
||||||
|
|
||||||
|
@functools.wraps(resolver)
|
||||||
|
async def _auth_resolver(*args, **kwargs):
|
||||||
|
if f.public:
|
||||||
|
return await self._run_resolver(resolver, *args, **kwargs)
|
||||||
|
|
||||||
|
user = get_user()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
||||||
|
|
||||||
|
if f.require_any_permission:
|
||||||
|
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
||||||
|
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||||
|
|
||||||
|
if f.require_any:
|
||||||
|
perms, resolvers = f.require_any
|
||||||
|
if not any([await user.has_permission(p) for p in perms]):
|
||||||
|
ctx = QueryContext([x.name for x in await user.permissions])
|
||||||
|
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
||||||
|
|
||||||
|
if not any(resolved):
|
||||||
|
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||||
|
|
||||||
|
return await self._run_resolver(resolver, *args, **kwargs)
|
||||||
|
|
||||||
|
if sig:
|
||||||
|
_auth_resolver.__signature__ = sig
|
||||||
|
|
||||||
|
return _auth_resolver
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||||
|
if iscoroutinefunction(r):
|
||||||
|
return await r(*args, **kwargs)
|
||||||
|
return r(*args, **kwargs)
|
||||||
|
|
||||||
|
def _field_to_strawberry(self, f: Field) -> Any:
|
||||||
|
resolver = None
|
||||||
|
try:
|
||||||
|
if f.arguments:
|
||||||
|
resolver = self._build_resolver(f)
|
||||||
|
elif not f.resolver:
|
||||||
|
resolver = lambda *_, **__: None
|
||||||
|
else:
|
||||||
|
ann = getattr(f.resolver, "__annotations__", {})
|
||||||
|
if "return" not in ann or ann["return"] is None:
|
||||||
|
ann = dict(ann)
|
||||||
|
ann["return"] = f.type
|
||||||
|
f.resolver.__annotations__ = ann
|
||||||
|
resolver = f.resolver
|
||||||
|
|
||||||
|
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||||
|
except StrawberryException as e:
|
||||||
|
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
|
||||||
|
|
||||||
|
def to_strawberry(self) -> Type:
|
||||||
|
cls = self.__class__
|
||||||
|
if TypeCollector.has(cls):
|
||||||
|
return TypeCollector.get(cls)
|
||||||
|
|
||||||
|
annotations: dict[str, Any] = {}
|
||||||
|
namespace: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for name, f in self._fields.items():
|
||||||
|
annotations[name] = f.type
|
||||||
|
namespace[name] = self._field_to_strawberry(f)
|
||||||
|
|
||||||
|
namespace["__annotations__"] = annotations
|
||||||
|
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||||
|
TypeCollector.set(cls, gql_type)
|
||||||
|
return gql_type
|
||||||
@@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter
|
|||||||
from cpl.graphql.schema.filter.filter import Filter
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||||
|
from cpl.graphql.schema.root_mutation import RootMutation
|
||||||
from cpl.graphql.schema.root_query import RootQuery
|
from cpl.graphql.schema.root_query import RootQuery
|
||||||
from cpl.graphql.service.schema import Schema
|
|
||||||
from cpl.graphql.service.graphql import GraphQLService
|
from cpl.graphql.service.graphql import GraphQLService
|
||||||
|
from cpl.graphql.service.schema import Schema
|
||||||
|
|
||||||
|
|
||||||
class GraphQLModule(Module):
|
class GraphQLModule(Module):
|
||||||
dependencies = [ApiModule]
|
dependencies = [ApiModule]
|
||||||
singleton = [Schema, RootQuery]
|
singleton = [Schema, RootQuery, RootMutation]
|
||||||
scoped = [GraphQLService]
|
scoped = [GraphQLService]
|
||||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Any
|
from typing import Optional
|
||||||
|
|
||||||
from graphql import GraphQLResolveInfo
|
from graphql import GraphQLResolveInfo
|
||||||
|
|
||||||
from cpl.auth.schema import AuthUser, Permission
|
from cpl.auth.schema import AuthUser, Permission
|
||||||
from cpl.core.ctx import get_user
|
from cpl.core.ctx import get_user
|
||||||
from cpl.core.utils import get_value
|
|
||||||
|
|
||||||
|
|
||||||
class QueryContext:
|
class QueryContext:
|
||||||
|
|||||||
@@ -1,38 +1,54 @@
|
|||||||
from typing import Any
|
from typing import Any, Self
|
||||||
|
|
||||||
|
|
||||||
class Argument:
|
class Argument:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
t: type,
|
|
||||||
name: str,
|
name: str,
|
||||||
|
t: type,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
default_value: Any = None,
|
default: Any = None,
|
||||||
optional: bool = None,
|
optional: bool = None,
|
||||||
):
|
):
|
||||||
self._type = t
|
|
||||||
self._name = name
|
self._name = name
|
||||||
|
self._type = t
|
||||||
self._description = description
|
self._description = description
|
||||||
self._default_value = default_value
|
self._default = default
|
||||||
self._optional = optional
|
self._optional = optional
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> type:
|
|
||||||
return self._type
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> type:
|
||||||
|
return self._type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str | None:
|
def description(self) -> str | None:
|
||||||
return self._description
|
return self._description
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value(self) -> Any | None:
|
def default(self) -> Any | None:
|
||||||
return self._default_value
|
return self._default
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optional(self) -> bool | None:
|
def optional(self) -> bool | None:
|
||||||
return self._optional
|
return self._optional
|
||||||
|
|
||||||
|
def with_description(self, description: str) -> Self:
|
||||||
|
self._description = description
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_default(self, default: Any) -> Self:
|
||||||
|
self._default = default
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_optional(self, optional: bool) -> Self:
|
||||||
|
self._optional = optional
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_required(self, required: bool = True) -> Self:
|
||||||
|
self._optional = not required
|
||||||
|
return self
|
||||||
|
|||||||
@@ -91,22 +91,26 @@ class Field:
|
|||||||
self._optional = optional
|
self._optional = optional
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_required(self, required: bool = True) -> Self:
|
||||||
|
self._optional = not required
|
||||||
|
return self
|
||||||
|
|
||||||
def with_default(self, default) -> Self:
|
def with_default(self, default) -> Self:
|
||||||
self._default = default
|
self._default = default
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
|
||||||
if name in self._args:
|
if name in self._args:
|
||||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||||
self._args[name] = Argument(arg_type, name, description, default_value, optional)
|
self._args[name] = Argument(name, arg_type, description, default_value, optional)
|
||||||
return self
|
return self._args[name]
|
||||||
|
|
||||||
def with_arguments(self, args: list[Argument]) -> Self:
|
def with_arguments(self, args: list[Argument]) -> Self:
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if not isinstance(arg, Argument):
|
if not isinstance(arg, Argument):
|
||||||
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
||||||
|
|
||||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
|
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
|
||||||
@@ -126,7 +130,7 @@ class Field:
|
|||||||
self._require_any = (permissions, resolvers)
|
self._require_any = (permissions, resolvers)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_public(self, public: bool = False) -> Self:
|
def with_public(self, public: bool = True) -> Self:
|
||||||
assert self._require_any is None, "Field cannot be public and have require_any set"
|
assert self._require_any is None, "Field cannot be public and have require_any set"
|
||||||
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
|
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
|
||||||
self._public = public
|
self._public = public
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Generic, Dict, Type, Optional, Self, Union
|
from typing import Generic, Dict, Type, Optional, Union, Any
|
||||||
|
|
||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
@@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
|||||||
class Input(StrawberryProtocol, Generic[T]):
|
class Input(StrawberryProtocol, Generic[T]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._fields: Dict[str, Field] = {}
|
self._fields: Dict[str, Field] = {}
|
||||||
|
self._values: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fields(self) -> Dict[str, Field]:
|
||||||
|
return self._fields
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item in self._values:
|
||||||
|
return self._values[item]
|
||||||
|
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key in {"_fields", "_values"}:
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
elif key in self._fields:
|
||||||
|
self._values[key] = value
|
||||||
|
else:
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def get(self, key: str, default=None):
|
||||||
|
return self._values.get(key, default)
|
||||||
|
|
||||||
def get_fields(self) -> dict[str, Field]:
|
def get_fields(self) -> dict[str, Field]:
|
||||||
return self._fields
|
return self._fields
|
||||||
|
|
||||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
||||||
self._fields[name] = Field(name, typ, optional=optional)
|
self._fields[name] = Field(name, typ, optional=optional)
|
||||||
|
return self._fields[name]
|
||||||
|
|
||||||
|
def string_field(self, name: str, optional: bool = True) -> Field:
|
||||||
|
return self.field(name, str)
|
||||||
|
|
||||||
|
def int_field(self, name: str, optional: bool = True) -> Field:
|
||||||
|
return self.field(name, int, optional)
|
||||||
|
|
||||||
|
def float_field(self, name: str, optional: bool = True) -> Field:
|
||||||
|
return self.field(name, float, optional)
|
||||||
|
|
||||||
|
def bool_field(self, name: str, optional: bool = True) -> Field:
|
||||||
|
return self.field(name, bool, optional)
|
||||||
|
|
||||||
|
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
|
||||||
|
return self.field(name, list[t], optional)
|
||||||
|
|
||||||
|
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||||
|
return self.field(name, t().to_strawberry(), optional)
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
def to_strawberry(self) -> Type:
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
|
|||||||
25
src/cpl-graphql/cpl/graphql/schema/mutation.py
Normal file
25
src/cpl-graphql/cpl/graphql/schema/mutation.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from cpl.dependency.inject import inject
|
||||||
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
from cpl.graphql.abc.query_abc import QueryABC
|
||||||
|
from cpl.graphql.schema.field import Field
|
||||||
|
|
||||||
|
|
||||||
|
class Mutation(QueryABC):
|
||||||
|
|
||||||
|
@inject
|
||||||
|
def __init__(self, provider: ServiceProvider):
|
||||||
|
QueryABC.__init__(self)
|
||||||
|
self._provider = provider
|
||||||
|
|
||||||
|
from cpl.graphql.service.schema import Schema
|
||||||
|
|
||||||
|
self._schema = provider.get_service(Schema)
|
||||||
|
|
||||||
|
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
|
||||||
|
sub = self._provider.get_service(cls)
|
||||||
|
if not sub:
|
||||||
|
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
|
||||||
|
|
||||||
|
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||||
@@ -1,76 +1,32 @@
|
|||||||
import functools
|
from typing import Callable, Type
|
||||||
import inspect
|
|
||||||
from asyncio import iscoroutinefunction
|
|
||||||
from typing import Callable, Type, Any, Optional
|
|
||||||
|
|
||||||
import strawberry
|
|
||||||
from strawberry.exceptions import StrawberryException
|
|
||||||
|
|
||||||
from cpl.api import Unauthorized, Forbidden
|
|
||||||
from cpl.core.ctx import get_user
|
|
||||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||||
from cpl.dependency.inject import inject
|
from cpl.dependency.inject import inject
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
from cpl.graphql.abc.query_abc import QueryABC
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
from cpl.graphql.error import graphql_error
|
|
||||||
from cpl.graphql.query_context import QueryContext
|
|
||||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
from cpl.graphql.typing import Resolver
|
|
||||||
from cpl.graphql.utils.type_collector import TypeCollector
|
|
||||||
|
|
||||||
|
|
||||||
class Query(StrawberryProtocol):
|
class Query(QueryABC):
|
||||||
|
|
||||||
@inject
|
@inject
|
||||||
def __init__(self, provider: ServiceProvider):
|
def __init__(self, provider: ServiceProvider):
|
||||||
|
QueryABC.__init__(self)
|
||||||
self._provider = provider
|
self._provider = provider
|
||||||
|
|
||||||
from cpl.graphql.service.schema import Schema
|
from cpl.graphql.service.schema import Schema
|
||||||
|
|
||||||
self._schema = provider.get_service(Schema)
|
self._schema = provider.get_service(Schema)
|
||||||
self._fields: dict[str, Field] = {}
|
|
||||||
|
|
||||||
def get_fields(self) -> dict[str, Field]:
|
def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
|
||||||
return self._fields
|
|
||||||
|
|
||||||
def field(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
t: type,
|
|
||||||
resolver: Resolver = None,
|
|
||||||
) -> Field:
|
|
||||||
from cpl.graphql.schema.field import Field
|
|
||||||
|
|
||||||
self._fields[name] = Field(name, t, resolver)
|
|
||||||
return self._fields[name]
|
|
||||||
|
|
||||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
|
||||||
return self.field(name, str, resolver)
|
|
||||||
|
|
||||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
|
||||||
return self.field(name, int, resolver)
|
|
||||||
|
|
||||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
|
||||||
return self.field(name, float, resolver)
|
|
||||||
|
|
||||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
|
||||||
return self.field(name, bool, resolver)
|
|
||||||
|
|
||||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
|
||||||
return self.field(name, list[t], resolver)
|
|
||||||
|
|
||||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
|
||||||
return self.field(name, t().to_strawberry(), resolver)
|
|
||||||
|
|
||||||
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
|
||||||
sub = self._provider.get_service(subquery_cls)
|
sub = self._provider.get_service(subquery_cls)
|
||||||
if not sub:
|
if not sub:
|
||||||
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
||||||
|
|
||||||
self.field(name, sub.to_strawberry(), lambda: sub)
|
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||||
return self
|
|
||||||
|
|
||||||
def collection_field(
|
def collection_field(
|
||||||
self,
|
self,
|
||||||
@@ -105,10 +61,10 @@ class Query(StrawberryProtocol):
|
|||||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||||
|
|
||||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
||||||
f.with_argument(filter.to_strawberry(), "filter")
|
f.with_argument("filter", filter.to_strawberry())
|
||||||
f.with_argument(sort.to_strawberry(), "sort")
|
f.with_argument("sort", sort.to_strawberry())
|
||||||
f.with_argument(int, "skip", default_value=0)
|
f.with_argument("skip", int, default_value=0)
|
||||||
f.with_argument(int, "take", default_value=10)
|
f.with_argument("take", int, default_value=10)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
def dao_collection_field(
|
def dao_collection_field(
|
||||||
@@ -168,120 +124,8 @@ class Query(StrawberryProtocol):
|
|||||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||||
|
|
||||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||||
f.with_argument(filter.to_strawberry(), "filter")
|
f.with_argument("filter", filter.to_strawberry())
|
||||||
f.with_argument(sort.to_strawberry(), "sort")
|
f.with_argument("sort", sort.to_strawberry())
|
||||||
f.with_argument(int, "skip", default_value=0)
|
f.with_argument("skip", int, default_value=0)
|
||||||
f.with_argument(int, "take", default_value=10)
|
f.with_argument("take", int, default_value=10)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_resolver(f: "Field"):
|
|
||||||
params: list[inspect.Parameter] = []
|
|
||||||
for arg in f.arguments.values():
|
|
||||||
ann = Optional[arg.type] if arg.optional else arg.type
|
|
||||||
|
|
||||||
if arg.default_value is None:
|
|
||||||
param = inspect.Parameter(
|
|
||||||
arg.name,
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
annotation=ann,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
param = inspect.Parameter(
|
|
||||||
arg.name,
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
annotation=ann,
|
|
||||||
default=arg.default_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
params.append(param)
|
|
||||||
|
|
||||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
|
||||||
|
|
||||||
async def _resolver(*args, **kwargs):
|
|
||||||
if f.resolver is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if iscoroutinefunction(f.resolver):
|
|
||||||
return await f.resolver(*args, **kwargs)
|
|
||||||
return f.resolver(*args, **kwargs)
|
|
||||||
|
|
||||||
_resolver.__signature__ = sig
|
|
||||||
return _resolver
|
|
||||||
|
|
||||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
|
||||||
sig = getattr(resolver, "__signature__", None)
|
|
||||||
|
|
||||||
@functools.wraps(resolver)
|
|
||||||
async def _auth_resolver(*args, **kwargs):
|
|
||||||
if f.public:
|
|
||||||
return await self._run_resolver(resolver, *args, **kwargs)
|
|
||||||
|
|
||||||
user = get_user()
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
|
||||||
|
|
||||||
if f.require_any_permission:
|
|
||||||
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
|
||||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
|
||||||
|
|
||||||
if f.require_any:
|
|
||||||
perms, resolvers = f.require_any
|
|
||||||
if not any([await user.has_permission(p) for p in perms]):
|
|
||||||
ctx = QueryContext([x.name for x in await user.permissions])
|
|
||||||
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
|
||||||
|
|
||||||
if not any(resolved):
|
|
||||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
|
||||||
|
|
||||||
return await self._run_resolver(resolver, *args, **kwargs)
|
|
||||||
|
|
||||||
if sig:
|
|
||||||
_auth_resolver.__signature__ = sig
|
|
||||||
|
|
||||||
return _auth_resolver
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
|
||||||
if iscoroutinefunction(r):
|
|
||||||
return await r(*args, **kwargs)
|
|
||||||
return r(*args, **kwargs)
|
|
||||||
|
|
||||||
def _field_to_strawberry(self, f: Field) -> Any:
|
|
||||||
resolver = None
|
|
||||||
try:
|
|
||||||
if f.arguments:
|
|
||||||
resolver = self._build_resolver(f)
|
|
||||||
elif not f.resolver:
|
|
||||||
resolver = lambda *_, **__: None
|
|
||||||
else:
|
|
||||||
ann = getattr(f.resolver, "__annotations__", {})
|
|
||||||
if "return" not in ann or ann["return"] is None:
|
|
||||||
ann = dict(ann)
|
|
||||||
ann["return"] = f.type
|
|
||||||
f.resolver.__annotations__ = ann
|
|
||||||
resolver = f.resolver
|
|
||||||
|
|
||||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
|
||||||
except StrawberryException as e:
|
|
||||||
raise Exception(
|
|
||||||
f"Error converting field '{f.name}' to strawberry field: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
|
||||||
cls = self.__class__
|
|
||||||
if TypeCollector.has(cls):
|
|
||||||
return TypeCollector.get(cls)
|
|
||||||
|
|
||||||
annotations: dict[str, Any] = {}
|
|
||||||
namespace: dict[str, Any] = {}
|
|
||||||
|
|
||||||
for name, f in self._fields.items():
|
|
||||||
annotations[name] = f.type
|
|
||||||
namespace[name] = self._field_to_strawberry(f)
|
|
||||||
|
|
||||||
namespace["__annotations__"] = annotations
|
|
||||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
|
||||||
TypeCollector.set(cls, gql_type)
|
|
||||||
return gql_type
|
|
||||||
|
|||||||
6
src/cpl-graphql/cpl/graphql/schema/root_mutation.py
Normal file
6
src/cpl-graphql/cpl/graphql/schema/root_mutation.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from cpl.graphql.schema.mutation import Mutation
|
||||||
|
|
||||||
|
|
||||||
|
class RootMutation(Mutation):
|
||||||
|
def __init__(self):
|
||||||
|
Mutation.__init__(self)
|
||||||
@@ -6,6 +6,7 @@ import strawberry
|
|||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
|
from cpl.graphql.schema.root_mutation import RootMutation
|
||||||
from cpl.graphql.schema.root_query import RootQuery
|
from cpl.graphql.schema.root_query import RootQuery
|
||||||
|
|
||||||
|
|
||||||
@@ -25,7 +26,17 @@ class Schema:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def query(self) -> RootQuery:
|
def query(self) -> RootQuery:
|
||||||
return self._provider.get_service(RootQuery)
|
query = self._provider.get_service(RootQuery)
|
||||||
|
if not query:
|
||||||
|
raise ValueError("RootQuery not registered in service provider")
|
||||||
|
return query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mutation(self) -> RootMutation:
|
||||||
|
mutation = self._provider.get_service(RootMutation)
|
||||||
|
if not mutation:
|
||||||
|
raise ValueError("RootMutation not registered in service provider")
|
||||||
|
return mutation
|
||||||
|
|
||||||
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||||
self._types[t.__name__] = t
|
self._types[t.__name__] = t
|
||||||
@@ -43,13 +54,13 @@ class Schema:
|
|||||||
|
|
||||||
def build(self) -> strawberry.Schema:
|
def build(self) -> strawberry.Schema:
|
||||||
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
||||||
query = self._provider.get_service(RootQuery)
|
|
||||||
if not query:
|
query = self.query
|
||||||
raise ValueError("RootQuery not registered in service provider")
|
mutation = self.mutation
|
||||||
|
|
||||||
self._schema = strawberry.Schema(
|
self._schema = strawberry.Schema(
|
||||||
query=query.to_strawberry(),
|
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||||
mutation=None,
|
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||||
subscription=None,
|
subscription=None,
|
||||||
types=self._get_types(),
|
types=self._get_types(),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user