Changed to strawberry #181
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from api.src.queries.cities import CityGraphType
|
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
|
||||||
from api.src.queries.hello import UserGraphType
|
from api.src.queries.hello import UserGraphType
|
||||||
|
from api.src.queries.user import UserFilter, UserSort
|
||||||
from cpl.api.api_module import ApiModule
|
from cpl.api.api_module import ApiModule
|
||||||
from cpl.application.application_builder import ApplicationBuilder
|
from cpl.application.application_builder import ApplicationBuilder
|
||||||
from cpl.auth.permission.permissions import Permissions
|
from cpl.auth.permission.permissions import Permissions
|
||||||
@@ -38,7 +39,13 @@ def main():
|
|||||||
builder.services.add_cache(Role)
|
builder.services.add_cache(Role)
|
||||||
|
|
||||||
builder.services.add_transient(CityGraphType)
|
builder.services.add_transient(CityGraphType)
|
||||||
|
builder.services.add_transient(CityFilter)
|
||||||
|
builder.services.add_transient(CitySort)
|
||||||
|
|
||||||
builder.services.add_transient(UserGraphType)
|
builder.services.add_transient(UserGraphType)
|
||||||
|
builder.services.add_transient(UserFilter)
|
||||||
|
builder.services.add_transient(UserSort)
|
||||||
|
|
||||||
builder.services.add_transient(HelloQuery)
|
builder.services.add_transient(HelloQuery)
|
||||||
|
|
||||||
app = builder.build()
|
app = builder.build()
|
||||||
@@ -57,7 +64,7 @@ def main():
|
|||||||
app.with_routes_directory("routes")
|
app.with_routes_directory("routes")
|
||||||
|
|
||||||
schema = app.with_graphql()
|
schema = app.with_graphql()
|
||||||
schema.query.string_field("ping", resolver=lambda *_: "pong")
|
schema.query.string_field("ping", resolver=lambda: "pong")
|
||||||
schema.query.with_query("hello", HelloQuery)
|
schema.query.with_query("hello", HelloQuery)
|
||||||
|
|
||||||
app.with_playground()
|
app.with_playground()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from cpl.graphql.schema.filter.filter import Filter
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
from cpl.graphql.schema.object_graph_type import ObjectGraphType
|
from cpl.graphql.schema.graph_type import GraphType
|
||||||
|
|
||||||
from cpl.graphql.schema.sort.sort import Sort
|
from cpl.graphql.schema.sort.sort import Sort
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
@@ -25,15 +25,15 @@ class CitySort(Sort[City]):
|
|||||||
self.field("name", SortOrder)
|
self.field("name", SortOrder)
|
||||||
|
|
||||||
|
|
||||||
class CityGraphType(ObjectGraphType):
|
class CityGraphType(GraphType[City]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
ObjectGraphType.__init__(self)
|
GraphType.__init__(self)
|
||||||
|
|
||||||
self.string_field(
|
self.int_field(
|
||||||
"id",
|
"id",
|
||||||
resolver=lambda user, *_: user.id,
|
resolver=lambda root: root.id,
|
||||||
)
|
)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"name",
|
"name",
|
||||||
resolver=lambda user, *_: user.name,
|
resolver=lambda root: root.name,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class HelloQuery(Query):
|
|||||||
Query.__init__(self)
|
Query.__init__(self)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"message",
|
"message",
|
||||||
resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}",
|
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
|
||||||
).with_argument(str, "name", "Name to greet", "world")
|
).with_argument(str, "name", "Name to greet", "world")
|
||||||
|
|
||||||
self.collection_field(
|
self.collection_field(
|
||||||
@@ -19,12 +19,12 @@ class HelloQuery(Query):
|
|||||||
"users",
|
"users",
|
||||||
UserFilter,
|
UserFilter,
|
||||||
UserSort,
|
UserSort,
|
||||||
resolver=lambda *_: users,
|
resolver=lambda: users,
|
||||||
)
|
)
|
||||||
self.collection_field(
|
self.collection_field(
|
||||||
CityGraphType,
|
CityGraphType,
|
||||||
"cities",
|
"cities",
|
||||||
CityFilter,
|
CityFilter,
|
||||||
CitySort,
|
CitySort,
|
||||||
resolver=lambda *_: cities,
|
resolver=lambda: cities,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from cpl.graphql.schema.filter.filter import Filter
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
from cpl.graphql.schema.object_graph_type import ObjectGraphType
|
from cpl.graphql.schema.graph_type import GraphType
|
||||||
|
|
||||||
from cpl.graphql.schema.sort.sort import Sort
|
from cpl.graphql.schema.sort.sort import Sort
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
|
|
||||||
@@ -25,15 +24,16 @@ class UserSort(Sort[User]):
|
|||||||
self.field("name", SortOrder)
|
self.field("name", SortOrder)
|
||||||
|
|
||||||
|
|
||||||
class UserGraphType(ObjectGraphType):
|
class UserGraphType(GraphType[User]):
|
||||||
def __init__(self):
|
|
||||||
ObjectGraphType.__init__(self)
|
|
||||||
|
|
||||||
self.string_field(
|
def __init__(self):
|
||||||
|
GraphType.__init__(self)
|
||||||
|
|
||||||
|
self.int_field(
|
||||||
"id",
|
"id",
|
||||||
resolver=lambda user, *_: user.id,
|
resolver=lambda root: root.id,
|
||||||
)
|
)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"name",
|
"name",
|
||||||
resolver=lambda user, *_: user.name,
|
resolver=lambda root: root.name,
|
||||||
)
|
)
|
||||||
|
|||||||
0
src/cpl-graphql/cpl/graphql/abc/__init__.py
Normal file
0
src/cpl-graphql/cpl/graphql/abc/__init__.py
Normal file
9
src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py
Normal file
9
src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from typing import Protocol, Type, runtime_checkable
|
||||||
|
|
||||||
|
from cpl.graphql.schema.field import Field
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class StrawberryProtocol(Protocol):
|
||||||
|
def to_strawberry(self) -> Type: ...
|
||||||
|
def get_fields(self) -> dict[str, Field]: ...
|
||||||
@@ -1,17 +1,15 @@
|
|||||||
from cpl.api.api_module import ApiModule
|
from cpl.api.api_module import ApiModule
|
||||||
from cpl.dependency.module.module import Module
|
from cpl.dependency.module.module import Module
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
from cpl.graphql.schema.collection import CollectionGraphType
|
|
||||||
from cpl.graphql.schema.root_query import RootQuery
|
from cpl.graphql.schema.root_query import RootQuery
|
||||||
from cpl.graphql.service.schema import Schema
|
from cpl.graphql.service.schema import Schema
|
||||||
from cpl.graphql.service.service import GraphQLService
|
from cpl.graphql.service.service import GraphQLService
|
||||||
from cpl.graphql.service.type_converter import TypeConverter
|
|
||||||
|
|
||||||
|
|
||||||
class GraphQLModule(Module):
|
class GraphQLModule(Module):
|
||||||
dependencies = [ApiModule]
|
dependencies = [ApiModule]
|
||||||
singleton = [TypeConverter, Schema]
|
singleton = [Schema, RootQuery]
|
||||||
scoped = [GraphQLService, RootQuery, CollectionGraphType]
|
scoped = [GraphQLService]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def configure(services: ServiceProvider) -> None:
|
def configure(services: ServiceProvider) -> None:
|
||||||
|
|||||||
@@ -1,9 +1,21 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class Argument:
|
class Argument:
|
||||||
def __init__(self, t: type, name: str, description: str = None, default_value=None):
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
t: type,
|
||||||
|
name: str,
|
||||||
|
description: str = None,
|
||||||
|
default_value: Any = None,
|
||||||
|
optional: bool = None,
|
||||||
|
):
|
||||||
self._type = t
|
self._type = t
|
||||||
self._name = name
|
self._name = name
|
||||||
self._description = description
|
self._description = description
|
||||||
self._default_value = default_value
|
self._default_value = default_value
|
||||||
|
self._optional = optional
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> type:
|
def type(self) -> type:
|
||||||
@@ -18,5 +30,9 @@ class Argument:
|
|||||||
return self._description
|
return self._description
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value(self):
|
def default_value(self) -> Any | None:
|
||||||
return self._default_value
|
return self._default_value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def optional(self) -> bool | None:
|
||||||
|
return self._optional
|
||||||
|
|||||||
@@ -1,18 +1,53 @@
|
|||||||
from typing import Generic, Type
|
from typing import Type, Dict, List
|
||||||
|
|
||||||
|
import strawberry
|
||||||
|
|
||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
from cpl.graphql.schema.graph_type import GraphType
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
|
|
||||||
|
|
||||||
class Collection(Generic[T]):
|
class CollectionGraphTypeFactory:
|
||||||
|
_cache: Dict[Type, Type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
|
||||||
|
if node_type in cls._cache:
|
||||||
|
return cls._cache[node_type]
|
||||||
|
|
||||||
|
gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||||
|
|
||||||
|
gql_type = strawberry.type(
|
||||||
|
type(
|
||||||
|
f"{node_type.__name__}Collection",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"__annotations__": {
|
||||||
|
"nodes": List[gql_node],
|
||||||
|
"total_count": int,
|
||||||
|
"count": int,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
cls._cache[node_type] = gql_type
|
||||||
|
return gql_type
|
||||||
|
|
||||||
|
|
||||||
|
class Collection:
|
||||||
def __init__(self, nodes: list[T], total_count: int, count: int):
|
def __init__(self, nodes: list[T], total_count: int, count: int):
|
||||||
self.nodes = nodes
|
self._nodes = nodes
|
||||||
self.totalCount = total_count
|
self._total_count = total_count
|
||||||
self.count = count
|
self._count = count
|
||||||
|
|
||||||
class CollectionGraphType(GraphType[T]):
|
@property
|
||||||
def __init__(self, t: Type[GraphType[T]]):
|
def nodes(self) -> list[T]:
|
||||||
GraphType.__init__(self)
|
return self._nodes
|
||||||
self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount)
|
|
||||||
self.string_field("count", resolver=lambda obj, *_: obj.count)
|
@property
|
||||||
self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes)
|
def total_count(self) -> int:
|
||||||
|
return self._total_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
return self._count
|
||||||
|
|||||||
@@ -6,11 +6,24 @@ from cpl.graphql.typing import TQuery, Resolver
|
|||||||
|
|
||||||
class Field:
|
class Field:
|
||||||
|
|
||||||
def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
gql_type: type = None,
|
||||||
|
resolver: Resolver = None,
|
||||||
|
optional=None,
|
||||||
|
default=None,
|
||||||
|
subquery: TQuery = None,
|
||||||
|
parent_type=None,
|
||||||
|
):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._gql_type = gql_type
|
self._gql_type = gql_type
|
||||||
self._resolver = resolver
|
self._resolver = resolver
|
||||||
|
self._optional = optional or True
|
||||||
|
self._default = default
|
||||||
|
|
||||||
self._subquery = subquery
|
self._subquery = subquery
|
||||||
|
self._parent_type = parent_type
|
||||||
|
|
||||||
self._args: dict[str, Argument] = {}
|
self._args: dict[str, Argument] = {}
|
||||||
|
|
||||||
@@ -26,6 +39,14 @@ class Field:
|
|||||||
def resolver(self) -> callable:
|
def resolver(self) -> callable:
|
||||||
return self._resolver
|
return self._resolver
|
||||||
|
|
||||||
|
@property
|
||||||
|
def optional(self) -> bool | None:
|
||||||
|
return self._optional
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default(self):
|
||||||
|
return self._default
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> dict:
|
def args(self) -> dict:
|
||||||
return self._args
|
return self._args
|
||||||
@@ -34,10 +55,18 @@ class Field:
|
|||||||
def subquery(self) -> TQuery | None:
|
def subquery(self) -> TQuery | None:
|
||||||
return self._subquery
|
return self._subquery
|
||||||
|
|
||||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self:
|
@property
|
||||||
|
def parent_type(self):
|
||||||
|
return self._parent_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arguments(self) -> dict[str, Argument]:
|
||||||
|
return self._args
|
||||||
|
|
||||||
|
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
||||||
if name in self._args:
|
if name in self._args:
|
||||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||||
self._args[name] = Argument(arg_type, name, description, default_value)
|
self._args[name] = Argument(arg_type, name, description, default_value, optional)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_arguments(self, args: list[Argument]) -> Self:
|
def with_arguments(self, args: list[Argument]) -> Self:
|
||||||
@@ -45,5 +74,5 @@ class Field:
|
|||||||
if not isinstance(arg, Argument):
|
if not isinstance(arg, Argument):
|
||||||
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
||||||
|
|
||||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value)
|
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -3,7 +3,5 @@ from cpl.graphql.schema.input import Input
|
|||||||
|
|
||||||
|
|
||||||
class Filter(Input[T]):
|
class Filter(Input[T]):
|
||||||
def __init__(
|
def __init__(self):
|
||||||
self,
|
|
||||||
):
|
|
||||||
Input.__init__(self)
|
Input.__init__(self)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from cpl.core.typing import T
|
|||||||
from cpl.graphql.schema.query import Query
|
from cpl.graphql.schema.query import Query
|
||||||
|
|
||||||
|
|
||||||
class GraphType(Generic[T], Query):
|
class GraphType(Query, Generic[T]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Query.__init__(self)
|
Query.__init__(self)
|
||||||
@@ -1,26 +1,34 @@
|
|||||||
from datetime import datetime
|
from typing import Generic, Dict, Type, Any, Optional
|
||||||
from enum import Enum
|
|
||||||
from typing import Type, Generic
|
|
||||||
|
|
||||||
import graphene
|
import strawberry
|
||||||
|
|
||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
|
|
||||||
|
|
||||||
class Input(Generic[T], graphene.InputObjectType):
|
class Input(StrawberryProtocol, Generic[T]):
|
||||||
def __init__(
|
def __init__(self):
|
||||||
self,
|
self._fields: Dict[str, Field] = {}
|
||||||
):
|
|
||||||
graphene.InputObjectType.__init__(self)
|
|
||||||
self._fields: dict[str, Field] = {}
|
|
||||||
|
|
||||||
def get_fields(self) -> dict[str, Field]:
|
def get_fields(self) -> dict[str, Field]:
|
||||||
return self._fields
|
return self._fields
|
||||||
|
|
||||||
def field(
|
def field(self, name: str, typ: type, optional: bool = True):
|
||||||
self,
|
self._fields[name] = Field(name, typ, optional=optional)
|
||||||
field: str,
|
|
||||||
t: Type["Input"] | Type[int | str | bool | datetime | list | Enum],
|
def to_strawberry(self) -> Type:
|
||||||
):
|
annotations = {}
|
||||||
self._fields[field] = Field(field, t)
|
namespace = {}
|
||||||
|
|
||||||
|
for name, f in self._fields.items():
|
||||||
|
ann = f.type if not f.optional else Optional[f.type]
|
||||||
|
annotations[name] = ann
|
||||||
|
|
||||||
|
if f.optional:
|
||||||
|
namespace[name] = None
|
||||||
|
elif f.default is not None:
|
||||||
|
namespace[name] = f.default
|
||||||
|
|
||||||
|
namespace["__annotations__"] = annotations
|
||||||
|
return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace))
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
from cpl.core.typing import T
|
|
||||||
from cpl.graphql.schema.graph_type import GraphType
|
|
||||||
from cpl.graphql.schema.query import Query
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectGraphType(GraphType[T], Query):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
Query.__init__(self)
|
|
||||||
@@ -1,21 +1,27 @@
|
|||||||
from typing import Callable, Type
|
import inspect
|
||||||
|
from typing import Callable, Type, Any, Optional
|
||||||
|
|
||||||
from graphene import ObjectType
|
import strawberry
|
||||||
|
from strawberry.exceptions import StrawberryException
|
||||||
|
|
||||||
from cpl.graphql.schema.argument import Argument
|
from cpl.dependency.inject import inject
|
||||||
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
|
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
from cpl.graphql.schema.filter.filter import Filter
|
|
||||||
from cpl.graphql.schema.sort.sort import Sort
|
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
from cpl.graphql.typing import Resolver
|
from cpl.graphql.typing import Resolver
|
||||||
|
|
||||||
|
|
||||||
class Query(ObjectType):
|
class Query(StrawberryProtocol):
|
||||||
|
|
||||||
def __init__(self):
|
@inject
|
||||||
from cpl.graphql.schema.field import Field
|
def __init__(self, provider: ServiceProvider):
|
||||||
|
self._provider = provider
|
||||||
|
|
||||||
ObjectType.__init__(self)
|
from cpl.graphql.service.schema import Schema
|
||||||
|
|
||||||
|
self._schema = provider.get_service(Schema)
|
||||||
self._fields: dict[str, Field] = {}
|
self._fields: dict[str, Field] = {}
|
||||||
|
|
||||||
def get_fields(self) -> dict[str, Field]:
|
def get_fields(self) -> dict[str, Field]:
|
||||||
@@ -25,69 +31,137 @@ class Query(ObjectType):
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
t: type,
|
t: type,
|
||||||
resolver: Callable | None = None,
|
resolver: Resolver = None,
|
||||||
) -> "Field":
|
) -> Field:
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
|
|
||||||
self._fields[name] = Field(name, t, resolver)
|
self._fields[name] = Field(name, t, resolver)
|
||||||
return self._fields[name]
|
return self._fields[name]
|
||||||
|
|
||||||
def with_query(self, name: str, subquery: Type["Query"]):
|
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
from cpl.graphql.schema.field import Field
|
|
||||||
|
|
||||||
f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery)
|
|
||||||
self._fields[name] = f
|
|
||||||
return self._fields[name]
|
|
||||||
|
|
||||||
def string_field(self, name: str, resolver: Resolver = None) -> "Field":
|
|
||||||
return self.field(name, str, resolver)
|
return self.field(name, str, resolver)
|
||||||
|
|
||||||
def int_field(self, name: str, resolver: Resolver = None) -> "Field":
|
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
return self.field(name, int, resolver)
|
return self.field(name, int, resolver)
|
||||||
|
|
||||||
def float_field(self, name: str, resolver: Resolver = None) -> "Field":
|
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
return self.field(name, float, resolver)
|
return self.field(name, float, resolver)
|
||||||
|
|
||||||
def bool_field(self, name: str, resolver: Resolver = None) -> "Field":
|
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||||
return self.field(name, bool, resolver)
|
return self.field(name, bool, resolver)
|
||||||
|
|
||||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field":
|
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||||
return self.field(name, list[t], resolver)
|
return self.field(name, list[t], resolver)
|
||||||
|
|
||||||
|
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
||||||
|
sub = self._provider.get_service(subquery_cls)
|
||||||
|
if not sub:
|
||||||
|
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
||||||
|
|
||||||
|
self.field(name, sub.to_strawberry(), lambda: sub)
|
||||||
|
return self
|
||||||
|
|
||||||
def collection_field(
|
def collection_field(
|
||||||
self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None
|
self,
|
||||||
) -> "Field":
|
t: type,
|
||||||
from cpl.graphql.schema.collection import Collection, CollectionGraphType
|
name: str,
|
||||||
|
filter_type: Type[StrawberryProtocol],
|
||||||
|
sort_type: Type[StrawberryProtocol],
|
||||||
|
resolver: Callable,
|
||||||
|
) -> Field:
|
||||||
|
# self._schema.with_type(filter_type)
|
||||||
|
# self._schema.with_type(sort_type)
|
||||||
|
|
||||||
def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int):
|
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
|
||||||
items = resolver()
|
items = resolver()
|
||||||
|
if filter:
|
||||||
|
for field, value in filter.__dict__.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
items = [i for i in items if getattr(i, field) == value]
|
||||||
|
|
||||||
for field in filter or []:
|
if sort:
|
||||||
if filter[field] is None:
|
for field, direction in sort.__dict__.items():
|
||||||
continue
|
reverse = direction == SortOrder.DESC
|
||||||
|
items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse)
|
||||||
items = [item for item in items if getattr(item, field) == filter[field]]
|
|
||||||
|
|
||||||
for field in sort or []:
|
|
||||||
if sort[field] is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
reverse = sort[field] == SortOrder.DESC
|
|
||||||
items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse)
|
|
||||||
|
|
||||||
total_count = len(items)
|
total_count = len(items)
|
||||||
paged = items[skip : skip + take]
|
paged = items[skip : skip + take]
|
||||||
return Collection(nodes=paged, total_count=total_count, count=len(paged))
|
return Collection(nodes=paged, total_count=total_count, count=len(paged))
|
||||||
|
|
||||||
# base = getattr(t, "__gqlname__", t.__class__.__name__)
|
filter = self._provider.get_service(filter_type)
|
||||||
wrapper = CollectionGraphType(t)
|
if not filter:
|
||||||
# wrapper.set_graphql_name(f"{base}Collection")
|
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
|
||||||
f = self.field(name, wrapper, resolver=_resolve_collection)
|
|
||||||
return f.with_arguments(
|
sort = self._provider.get_service(sort_type)
|
||||||
[
|
if not sort:
|
||||||
Argument(filter_type, "filter"),
|
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||||
Argument(sort_type, "sort"),
|
|
||||||
Argument(int, "skip", default_value=0),
|
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
||||||
Argument(int, "take", default_value=10),
|
f.with_argument(filter.to_strawberry(), "filter")
|
||||||
]
|
f.with_argument(sort.to_strawberry(), "sort")
|
||||||
)
|
f.with_argument(int, "skip", default_value=0)
|
||||||
|
f.with_argument(int, "take", default_value=10)
|
||||||
|
return f
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_resolver(f: "Field"):
|
||||||
|
params: list[inspect.Parameter] = []
|
||||||
|
for arg in f.arguments.values():
|
||||||
|
ann = Optional[arg.type] if arg.optional else arg.type
|
||||||
|
|
||||||
|
if arg.default_value is None:
|
||||||
|
param = inspect.Parameter(
|
||||||
|
arg.name,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=ann,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
param = inspect.Parameter(
|
||||||
|
arg.name,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=ann,
|
||||||
|
default=arg.default_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
params.append(param)
|
||||||
|
|
||||||
|
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||||
|
|
||||||
|
def _resolver(*args, **kwargs):
|
||||||
|
return f.resolver(*args, **kwargs) if f.resolver else None
|
||||||
|
|
||||||
|
_resolver.__signature__ = sig
|
||||||
|
return _resolver
|
||||||
|
|
||||||
|
def _field_to_strawberry(self, f: Field) -> Any:
|
||||||
|
try:
|
||||||
|
if f.resolver:
|
||||||
|
ann = getattr(f.resolver, "__annotations__", {})
|
||||||
|
if "return" not in ann or ann["return"] is None:
|
||||||
|
ann = dict(ann)
|
||||||
|
ann["return"] = f.type
|
||||||
|
f.resolver.__annotations__ = ann
|
||||||
|
|
||||||
|
if f.arguments:
|
||||||
|
resolver = self._build_resolver(f)
|
||||||
|
return strawberry.field(resolver=resolver)
|
||||||
|
|
||||||
|
if not f.resolver:
|
||||||
|
return strawberry.field(resolver=lambda *_, **__: None)
|
||||||
|
|
||||||
|
return strawberry.field(resolver=f.resolver)
|
||||||
|
except StrawberryException as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Error converting field '{f.name}' to strawberry field: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def to_strawberry(self) -> Type:
|
||||||
|
annotations: dict[str, Any] = {}
|
||||||
|
namespace: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for name, f in self._fields.items():
|
||||||
|
annotations[name] = f.type
|
||||||
|
namespace[name] = self._field_to_strawberry(f)
|
||||||
|
|
||||||
|
namespace["__annotations__"] = annotations
|
||||||
|
return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace))
|
||||||
|
|||||||
@@ -1,43 +1,54 @@
|
|||||||
import graphene
|
from typing import Type, Self
|
||||||
|
|
||||||
|
import strawberry
|
||||||
|
|
||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
from cpl.graphql.schema.collection import CollectionGraphType
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
from cpl.graphql.schema.graph_type import GraphType
|
|
||||||
from cpl.graphql.schema.root_query import RootQuery
|
from cpl.graphql.schema.root_query import RootQuery
|
||||||
from cpl.graphql.service.type_converter import TypeConverter
|
|
||||||
|
|
||||||
|
|
||||||
class Schema:
|
class Schema:
|
||||||
|
|
||||||
def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider):
|
def __init__(self, logger: APILogger, provider: ServiceProvider):
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._provider = provider
|
self._provider = provider
|
||||||
self._converter = converter
|
|
||||||
|
|
||||||
self._types = set(GraphType.__subclasses__())
|
self._types: dict[str, Type[StrawberryProtocol]] = {}
|
||||||
self._types.remove(CollectionGraphType)
|
|
||||||
|
|
||||||
self._query = query
|
|
||||||
self._schema = None
|
self._schema = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schema(self) -> graphene.Schema | None:
|
def schema(self) -> strawberry.Schema | None:
|
||||||
return self._schema
|
return self._schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def query(self) -> RootQuery:
|
def query(self) -> RootQuery:
|
||||||
return self._query
|
return self._provider.get_service(RootQuery)
|
||||||
|
|
||||||
def with_type(self, t: type[GraphType]):
|
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||||
self._types.add(t)
|
self._types[t.__name__] = t
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> graphene.Schema:
|
def _get_types(self):
|
||||||
self._schema = graphene.Schema(
|
types: list[Type] = []
|
||||||
query=self._converter.to_graphene(self._query),
|
for t in self._types.values():
|
||||||
|
t_obj = self._provider.get_service(t)
|
||||||
|
if not t_obj:
|
||||||
|
raise ValueError(f"Type '{t.__name__}' not registered in service provider")
|
||||||
|
types.append(t_obj.to_strawberry())
|
||||||
|
|
||||||
|
return types
|
||||||
|
|
||||||
|
def build(self) -> strawberry.Schema:
|
||||||
|
query = self._provider.get_service(RootQuery)
|
||||||
|
if not query:
|
||||||
|
raise ValueError("RootQuery not registered in service provider")
|
||||||
|
|
||||||
|
self._schema = strawberry.Schema(
|
||||||
|
query=query.to_strawberry(),
|
||||||
mutation=None,
|
mutation=None,
|
||||||
subscription=None,
|
subscription=None,
|
||||||
# types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None,
|
types=self._get_types(),
|
||||||
)
|
)
|
||||||
return self._schema
|
return self._schema
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class GraphQLService:
|
|||||||
variables: Optional[Dict[str, Any]],
|
variables: Optional[Dict[str, Any]],
|
||||||
request: TRequest,
|
request: TRequest,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
result = await self._schema.execute_async(
|
result = await self._schema.execute(
|
||||||
query,
|
query,
|
||||||
variable_values=variables,
|
variable_values=variables,
|
||||||
context_value={"request": request},
|
context_value={"request": request},
|
||||||
|
|||||||
@@ -1,89 +0,0 @@
|
|||||||
import typing
|
|
||||||
from enum import Enum
|
|
||||||
from inspect import isclass
|
|
||||||
|
|
||||||
import graphene
|
|
||||||
from typing import Any, get_origin, get_args
|
|
||||||
|
|
||||||
from cpl.dependency import ServiceProvider
|
|
||||||
from cpl.graphql.schema.argument import Argument
|
|
||||||
from cpl.graphql.schema.filter.filter import Filter
|
|
||||||
from cpl.graphql.schema.graph_type import GraphType
|
|
||||||
from cpl.graphql.schema.object_graph_type import ObjectGraphType
|
|
||||||
from cpl.graphql.schema.sort.sort import Sort
|
|
||||||
from cpl.graphql.typing import Resolver
|
|
||||||
from cpl.graphql.utils.name_pipe import NamePipe
|
|
||||||
|
|
||||||
|
|
||||||
class TypeConverter:
|
|
||||||
__scalar_map: dict[Any, type[graphene.Scalar]] = {
|
|
||||||
str: graphene.String,
|
|
||||||
int: graphene.Int,
|
|
||||||
float: graphene.Float,
|
|
||||||
bool: graphene.Boolean,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, provider: ServiceProvider):
|
|
||||||
self._provider = provider
|
|
||||||
|
|
||||||
def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field:
|
|
||||||
arguments = {}
|
|
||||||
if args is not None:
|
|
||||||
arguments = {
|
|
||||||
arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value)
|
|
||||||
for arg in args.values()
|
|
||||||
}
|
|
||||||
|
|
||||||
return graphene.Field(t, args=arguments, resolver=resolver)
|
|
||||||
|
|
||||||
def to_graphene(self, t: Any, name: str | None = None) -> Any:
|
|
||||||
try:
|
|
||||||
origin = get_origin(t)
|
|
||||||
args = get_args(t)
|
|
||||||
|
|
||||||
if t in self.__scalar_map:
|
|
||||||
return self.__scalar_map[t]
|
|
||||||
|
|
||||||
if origin in (list, typing.List):
|
|
||||||
if not args:
|
|
||||||
raise ValueError("List must specify element type, e.g. list[str]")
|
|
||||||
inner = self.to_graphene(args[0])
|
|
||||||
return graphene.List(inner)
|
|
||||||
|
|
||||||
if t is list or t is typing.List:
|
|
||||||
raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]")
|
|
||||||
|
|
||||||
if isclass(t) and issubclass(t, Enum):
|
|
||||||
return graphene.Enum.from_enum(t)
|
|
||||||
|
|
||||||
from cpl.graphql.schema.query import Query
|
|
||||||
if isinstance(t, type) and issubclass(t, (Query)):
|
|
||||||
query = self._provider.get_service(t)
|
|
||||||
if query is None:
|
|
||||||
raise ValueError(f"Could not resolve query of type {t}")
|
|
||||||
|
|
||||||
t = query
|
|
||||||
|
|
||||||
if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)):
|
|
||||||
t = t()
|
|
||||||
|
|
||||||
if isinstance(t, (Query, Filter, Sort)):
|
|
||||||
attrs = {}
|
|
||||||
for field in t.get_fields().values():
|
|
||||||
if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None:
|
|
||||||
subquery = self._provider.get_service(field.subquery)
|
|
||||||
sub = self.to_graphene(subquery, name=field.name.capitalize())
|
|
||||||
attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver)
|
|
||||||
continue
|
|
||||||
|
|
||||||
attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver)
|
|
||||||
|
|
||||||
class_name = NamePipe.to_str(name or t.__class__)
|
|
||||||
if isinstance(t, (Filter, Sort)):
|
|
||||||
return type(class_name, (graphene.InputObjectType,), attrs)
|
|
||||||
|
|
||||||
return type(class_name, (graphene.ObjectType,), attrs)
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported field type: {t}")
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e
|
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
cpl-api
|
cpl-api
|
||||||
graphene==3.4.3
|
strawberry-graphql==0.282.0
|
||||||
Reference in New Issue
Block a user