WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
19 changed files with 317 additions and 230 deletions
Showing only changes of commit ada50c693e - Show all commits

View File

@@ -1,7 +1,8 @@
from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
from api.src.queries.hello import UserGraphType
from api.src.queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder
from cpl.auth.permission.permissions import Permissions
@@ -38,7 +39,13 @@ def main():
builder.services.add_cache(Role)
builder.services.add_transient(CityGraphType)
builder.services.add_transient(CityFilter)
builder.services.add_transient(CitySort)
builder.services.add_transient(UserGraphType)
builder.services.add_transient(UserFilter)
builder.services.add_transient(UserSort)
builder.services.add_transient(HelloQuery)
app = builder.build()
@@ -57,7 +64,7 @@ def main():
app.with_routes_directory("routes")
schema = app.with_graphql()
schema.query.string_field("ping", resolver=lambda *_: "pong")
schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery)
app.with_playground()

View File

@@ -1,5 +1,5 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.object_graph_type import ObjectGraphType
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
@@ -25,15 +25,15 @@ class CitySort(Sort[City]):
self.field("name", SortOrder)
class CityGraphType(ObjectGraphType):
class CityGraphType(GraphType[City]):
def __init__(self):
ObjectGraphType.__init__(self)
GraphType.__init__(self)
self.string_field(
self.int_field(
"id",
resolver=lambda user, *_: user.id,
resolver=lambda root: root.id,
)
self.string_field(
"name",
resolver=lambda user, *_: user.name,
resolver=lambda root: root.name,
)

View File

@@ -11,7 +11,7 @@ class HelloQuery(Query):
Query.__init__(self)
self.string_field(
"message",
resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}",
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
).with_argument(str, "name", "Name to greet", "world")
self.collection_field(
@@ -19,12 +19,12 @@ class HelloQuery(Query):
"users",
UserFilter,
UserSort,
resolver=lambda *_: users,
resolver=lambda: users,
)
self.collection_field(
CityGraphType,
"cities",
CityFilter,
CitySort,
resolver=lambda *_: cities,
resolver=lambda: cities,
)

View File

@@ -1,6 +1,5 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.object_graph_type import ObjectGraphType
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
@@ -25,15 +24,16 @@ class UserSort(Sort[User]):
self.field("name", SortOrder)
class UserGraphType(ObjectGraphType):
def __init__(self):
ObjectGraphType.__init__(self)
class UserGraphType(GraphType[User]):
self.string_field(
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda user, *_: user.id,
resolver=lambda root: root.id,
)
self.string_field(
"name",
resolver=lambda user, *_: user.name,
resolver=lambda root: root.name,
)

View File

@@ -0,0 +1,9 @@
from typing import Protocol, Type, runtime_checkable
from cpl.graphql.schema.field import Field
@runtime_checkable
class StrawberryProtocol(Protocol):
def to_strawberry(self) -> Type: ...
def get_fields(self) -> dict[str, Field]: ...

View File

@@ -1,17 +1,15 @@
from cpl.api.api_module import ApiModule
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.collection import CollectionGraphType
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema
from cpl.graphql.service.service import GraphQLService
from cpl.graphql.service.type_converter import TypeConverter
class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [TypeConverter, Schema]
scoped = [GraphQLService, RootQuery, CollectionGraphType]
singleton = [Schema, RootQuery]
scoped = [GraphQLService]
@staticmethod
def configure(services: ServiceProvider) -> None:

View File

@@ -1,9 +1,21 @@
from typing import Any
class Argument:
def __init__(self, t: type, name: str, description: str = None, default_value=None):
def __init__(
self,
t: type,
name: str,
description: str = None,
default_value: Any = None,
optional: bool = None,
):
self._type = t
self._name = name
self._description = description
self._default_value = default_value
self._optional = optional
@property
def type(self) -> type:
@@ -18,5 +30,9 @@ class Argument:
return self._description
@property
def default_value(self):
def default_value(self) -> Any | None:
return self._default_value
@property
def optional(self) -> bool | None:
return self._optional

View File

@@ -1,18 +1,53 @@
from typing import Generic, Type
from typing import Type, Dict, List
import strawberry
from cpl.core.typing import T
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
class Collection(Generic[T]):
class CollectionGraphTypeFactory:
_cache: Dict[Type, Type] = {}
@classmethod
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
if node_type in cls._cache:
return cls._cache[node_type]
gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_type = strawberry.type(
type(
f"{node_type.__name__}Collection",
(),
{
"__annotations__": {
"nodes": List[gql_node],
"total_count": int,
"count": int,
}
},
)
)
cls._cache[node_type] = gql_type
return gql_type
class Collection:
def __init__(self, nodes: list[T], total_count: int, count: int):
self.nodes = nodes
self.totalCount = total_count
self.count = count
self._nodes = nodes
self._total_count = total_count
self._count = count
class CollectionGraphType(GraphType[T]):
def __init__(self, t: Type[GraphType[T]]):
GraphType.__init__(self)
self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount)
self.string_field("count", resolver=lambda obj, *_: obj.count)
self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes)
@property
def nodes(self) -> list[T]:
return self._nodes
@property
def total_count(self) -> int:
return self._total_count
@property
def count(self) -> int:
return self._count

View File

@@ -6,11 +6,24 @@ from cpl.graphql.typing import TQuery, Resolver
class Field:
def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None):
def __init__(
self,
name: str,
gql_type: type = None,
resolver: Resolver = None,
optional=None,
default=None,
subquery: TQuery = None,
parent_type=None,
):
self._name = name
self._gql_type = gql_type
self._resolver = resolver
self._optional = optional or True
self._default = default
self._subquery = subquery
self._parent_type = parent_type
self._args: dict[str, Argument] = {}
@@ -26,6 +39,14 @@ class Field:
def resolver(self) -> callable:
return self._resolver
@property
def optional(self) -> bool | None:
return self._optional
@property
def default(self):
return self._default
@property
def args(self) -> dict:
return self._args
@@ -34,10 +55,18 @@ class Field:
def subquery(self) -> TQuery | None:
return self._subquery
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self:
@property
def parent_type(self):
return self._parent_type
@property
def arguments(self) -> dict[str, Argument]:
return self._args
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(arg_type, name, description, default_value)
self._args[name] = Argument(arg_type, name, description, default_value, optional)
return self
def with_arguments(self, args: list[Argument]) -> Self:
@@ -45,5 +74,5 @@ class Field:
if not isinstance(arg, Argument):
raise ValueError(f"Expected Argument instance, got {type(arg)}")
self.with_argument(arg.type, arg.name, arg.description, arg.default_value)
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
return self

View File

@@ -3,7 +3,5 @@ from cpl.graphql.schema.input import Input
class Filter(Input[T]):
def __init__(
self,
):
def __init__(self):
Input.__init__(self)

View File

@@ -4,7 +4,7 @@ from cpl.core.typing import T
from cpl.graphql.schema.query import Query
class GraphType(Generic[T], Query):
class GraphType(Query, Generic[T]):
def __init__(self):
Query.__init__(self)

View File

@@ -1,26 +1,34 @@
from datetime import datetime
from enum import Enum
from typing import Type, Generic
from typing import Generic, Dict, Type, Any, Optional
import graphene
import strawberry
from cpl.core.typing import T
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field
class Input(Generic[T], graphene.InputObjectType):
def __init__(
self,
):
graphene.InputObjectType.__init__(self)
self._fields: dict[str, Field] = {}
class Input(StrawberryProtocol, Generic[T]):
def __init__(self):
self._fields: Dict[str, Field] = {}
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
field: str,
t: Type["Input"] | Type[int | str | bool | datetime | list | Enum],
):
self._fields[field] = Field(field, t)
def field(self, name: str, typ: type, optional: bool = True):
self._fields[name] = Field(name, typ, optional=optional)
def to_strawberry(self) -> Type:
annotations = {}
namespace = {}
for name, f in self._fields.items():
ann = f.type if not f.optional else Optional[f.type]
annotations[name] = ann
if f.optional:
namespace[name] = None
elif f.default is not None:
namespace[name] = f.default
namespace["__annotations__"] = annotations
return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace))

View File

@@ -1,9 +0,0 @@
from cpl.core.typing import T
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
class ObjectGraphType(GraphType[T], Query):
def __init__(self):
Query.__init__(self)

View File

@@ -1,21 +1,27 @@
from typing import Callable, Type
import inspect
from typing import Callable, Type, Any, Optional
from graphene import ObjectType
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.graphql.schema.argument import Argument
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.typing import Resolver
class Query(ObjectType):
class Query(StrawberryProtocol):
def __init__(self):
from cpl.graphql.schema.field import Field
@inject
def __init__(self, provider: ServiceProvider):
self._provider = provider
ObjectType.__init__(self)
from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema)
self._fields: dict[str, Field] = {}
def get_fields(self) -> dict[str, Field]:
@@ -25,69 +31,137 @@ class Query(ObjectType):
self,
name: str,
t: type,
resolver: Callable | None = None,
) -> "Field":
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def with_query(self, name: str, subquery: Type["Query"]):
from cpl.graphql.schema.field import Field
f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery)
self._fields[name] = f
return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> "Field":
def string_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> "Field":
def int_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> "Field":
def float_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> "Field":
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field":
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def with_query(self, name: str, subquery_cls: Type["Query"]):
sub = self._provider.get_service(subquery_cls)
if not sub:
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
self.field(name, sub.to_strawberry(), lambda: sub)
return self
def collection_field(
self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None
) -> "Field":
from cpl.graphql.schema.collection import Collection, CollectionGraphType
self,
t: type,
name: str,
filter_type: Type[StrawberryProtocol],
sort_type: Type[StrawberryProtocol],
resolver: Callable,
) -> Field:
# self._schema.with_type(filter_type)
# self._schema.with_type(sort_type)
def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int):
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
items = resolver()
if filter:
for field, value in filter.__dict__.items():
if value is None:
continue
items = [i for i in items if getattr(i, field) == value]
for field in filter or []:
if filter[field] is None:
continue
items = [item for item in items if getattr(item, field) == filter[field]]
for field in sort or []:
if sort[field] is None:
continue
reverse = sort[field] == SortOrder.DESC
items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse)
if sort:
for field, direction in sort.__dict__.items():
reverse = direction == SortOrder.DESC
items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse)
total_count = len(items)
paged = items[skip : skip + take]
return Collection(nodes=paged, total_count=total_count, count=len(paged))
# base = getattr(t, "__gqlname__", t.__class__.__name__)
wrapper = CollectionGraphType(t)
# wrapper.set_graphql_name(f"{base}Collection")
f = self.field(name, wrapper, resolver=_resolve_collection)
return f.with_arguments(
[
Argument(filter_type, "filter"),
Argument(sort_type, "sort"),
Argument(int, "skip", default_value=0),
Argument(int, "take", default_value=10),
]
)
filter = self._provider.get_service(filter_type)
if not filter:
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
sort = self._provider.get_service(sort_type)
if not sort:
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
f.with_argument(filter.to_strawberry(), "filter")
f.with_argument(sort.to_strawberry(), "sort")
f.with_argument(int, "skip", default_value=0)
f.with_argument(int, "take", default_value=10)
return f
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
ann = Optional[arg.type] if arg.optional else arg.type
if arg.default_value is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default_value,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
def _resolver(*args, **kwargs):
return f.resolver(*args, **kwargs) if f.resolver else None
_resolver.__signature__ = sig
return _resolver
def _field_to_strawberry(self, f: Field) -> Any:
try:
if f.resolver:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
if f.arguments:
resolver = self._build_resolver(f)
return strawberry.field(resolver=resolver)
if not f.resolver:
return strawberry.field(resolver=lambda *_, **__: None)
return strawberry.field(resolver=f.resolver)
except StrawberryException as e:
raise Exception(
f"Error converting field '{f.name}' to strawberry field: {e}"
) from e
def to_strawberry(self) -> Type:
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
annotations[name] = f.type
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace))

View File

@@ -1,43 +1,54 @@
import graphene
from typing import Type, Self
import strawberry
from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.collection import CollectionGraphType
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.type_converter import TypeConverter
class Schema:
def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider):
def __init__(self, logger: APILogger, provider: ServiceProvider):
self._logger = logger
self._provider = provider
self._converter = converter
self._types = set(GraphType.__subclasses__())
self._types.remove(CollectionGraphType)
self._types: dict[str, Type[StrawberryProtocol]] = {}
self._query = query
self._schema = None
@property
def schema(self) -> graphene.Schema | None:
def schema(self) -> strawberry.Schema | None:
return self._schema
@property
def query(self) -> RootQuery:
return self._query
return self._provider.get_service(RootQuery)
def with_type(self, t: type[GraphType]):
self._types.add(t)
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
self._types[t.__name__] = t
return self
def build(self) -> graphene.Schema:
self._schema = graphene.Schema(
query=self._converter.to_graphene(self._query),
def _get_types(self):
types: list[Type] = []
for t in self._types.values():
t_obj = self._provider.get_service(t)
if not t_obj:
raise ValueError(f"Type '{t.__name__}' not registered in service provider")
types.append(t_obj.to_strawberry())
return types
def build(self) -> strawberry.Schema:
query = self._provider.get_service(RootQuery)
if not query:
raise ValueError("RootQuery not registered in service provider")
self._schema = strawberry.Schema(
query=query.to_strawberry(),
mutation=None,
subscription=None,
# types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None,
types=self._get_types(),
)
return self._schema

View File

@@ -16,7 +16,7 @@ class GraphQLService:
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute_async(
result = await self._schema.execute(
query,
variable_values=variables,
context_value={"request": request},

View File

@@ -1,89 +0,0 @@
import typing
from enum import Enum
from inspect import isclass
import graphene
from typing import Any, get_origin, get_args
from cpl.dependency import ServiceProvider
from cpl.graphql.schema.argument import Argument
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.object_graph_type import ObjectGraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.typing import Resolver
from cpl.graphql.utils.name_pipe import NamePipe
class TypeConverter:
__scalar_map: dict[Any, type[graphene.Scalar]] = {
str: graphene.String,
int: graphene.Int,
float: graphene.Float,
bool: graphene.Boolean,
}
def __init__(self, provider: ServiceProvider):
self._provider = provider
def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field:
arguments = {}
if args is not None:
arguments = {
arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value)
for arg in args.values()
}
return graphene.Field(t, args=arguments, resolver=resolver)
def to_graphene(self, t: Any, name: str | None = None) -> Any:
try:
origin = get_origin(t)
args = get_args(t)
if t in self.__scalar_map:
return self.__scalar_map[t]
if origin in (list, typing.List):
if not args:
raise ValueError("List must specify element type, e.g. list[str]")
inner = self.to_graphene(args[0])
return graphene.List(inner)
if t is list or t is typing.List:
raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]")
if isclass(t) and issubclass(t, Enum):
return graphene.Enum.from_enum(t)
from cpl.graphql.schema.query import Query
if isinstance(t, type) and issubclass(t, (Query)):
query = self._provider.get_service(t)
if query is None:
raise ValueError(f"Could not resolve query of type {t}")
t = query
if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)):
t = t()
if isinstance(t, (Query, Filter, Sort)):
attrs = {}
for field in t.get_fields().values():
if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None:
subquery = self._provider.get_service(field.subquery)
sub = self.to_graphene(subquery, name=field.name.capitalize())
attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver)
continue
attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver)
class_name = NamePipe.to_str(name or t.__class__)
if isinstance(t, (Filter, Sort)):
return type(class_name, (graphene.InputObjectType,), attrs)
return type(class_name, (graphene.ObjectType,), attrs)
raise ValueError(f"Unsupported field type: {t}")
except Exception as e:
raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e

View File

@@ -1,2 +1,2 @@
cpl-api
graphene==3.4.3
strawberry-graphql==0.282.0