diff --git a/example/api/src/main.py b/example/api/src/main.py index b1f525cc..bfb953fb 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,7 +1,8 @@ from starlette.responses import JSONResponse -from api.src.queries.cities import CityGraphType +from api.src.queries.cities import CityGraphType, CityFilter, CitySort from api.src.queries.hello import UserGraphType +from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder from cpl.auth.permission.permissions import Permissions @@ -38,7 +39,13 @@ def main(): builder.services.add_cache(Role) builder.services.add_transient(CityGraphType) + builder.services.add_transient(CityFilter) + builder.services.add_transient(CitySort) + builder.services.add_transient(UserGraphType) + builder.services.add_transient(UserFilter) + builder.services.add_transient(UserSort) + builder.services.add_transient(HelloQuery) app = builder.build() @@ -57,7 +64,7 @@ def main(): app.with_routes_directory("routes") schema = app.with_graphql() - schema.query.string_field("ping", resolver=lambda *_: "pong") + schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) app.with_playground() diff --git a/example/api/src/queries/cities.py b/example/api/src/queries/cities.py index 4234f8e2..7fd88273 100644 --- a/example/api/src/queries/cities.py +++ b/example/api/src/queries/cities.py @@ -1,5 +1,5 @@ from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.object_graph_type import ObjectGraphType +from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder @@ -25,15 +25,15 @@ class CitySort(Sort[City]): self.field("name", SortOrder) -class CityGraphType(ObjectGraphType): +class CityGraphType(GraphType[City]): def __init__(self): - ObjectGraphType.__init__(self) + GraphType.__init__(self) - self.string_field( + self.int_field( "id", - resolver=lambda user, *_: user.id, + resolver=lambda root: root.id, ) self.string_field( "name", - resolver=lambda user, *_: user.name, + resolver=lambda root: root.name, ) diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index 0f61c27c..2f2ba633 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -11,7 +11,7 @@ class HelloQuery(Query): Query.__init__(self) self.string_field( "message", - resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}", + resolver=lambda name: f"Hello {name} {get_request().state.request_id}", ).with_argument(str, "name", "Name to greet", "world") self.collection_field( @@ -19,12 +19,12 @@ class HelloQuery(Query): "users", UserFilter, UserSort, - resolver=lambda *_: users, + resolver=lambda: users, ) self.collection_field( CityGraphType, "cities", CityFilter, CitySort, - resolver=lambda *_: cities, + resolver=lambda: cities, ) diff --git a/example/api/src/queries/user.py b/example/api/src/queries/user.py index 3c4dd70c..a35a1780 100644 --- a/example/api/src/queries/user.py +++ b/example/api/src/queries/user.py @@ -1,6 +1,5 @@ from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.object_graph_type import ObjectGraphType - +from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder @@ -25,15 +24,16 @@ class UserSort(Sort[User]): self.field("name", SortOrder) -class UserGraphType(ObjectGraphType): - def __init__(self): - ObjectGraphType.__init__(self) +class UserGraphType(GraphType[User]): - self.string_field( + def __init__(self): + GraphType.__init__(self) + + self.int_field( "id", - resolver=lambda user, *_: user.id, + resolver=lambda root: root.id, ) self.string_field( "name", - resolver=lambda user, *_: user.name, + resolver=lambda root: root.name, ) diff --git a/src/cpl-graphql/cpl/graphql/abc/__init__.py b/src/cpl-graphql/cpl/graphql/abc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py new file mode 100644 index 00000000..1c0b6592 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py @@ -0,0 +1,9 @@ +from typing import Protocol, Type, runtime_checkable + +from cpl.graphql.schema.field import Field + + +@runtime_checkable +class StrawberryProtocol(Protocol): + def to_strawberry(self) -> Type: ... + def get_fields(self) -> dict[str, Field]: ... diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 29d9d79d..70efa400 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,17 +1,15 @@ from cpl.api.api_module import ApiModule from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider -from cpl.graphql.schema.collection import CollectionGraphType from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema from cpl.graphql.service.service import GraphQLService -from cpl.graphql.service.type_converter import TypeConverter class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [TypeConverter, Schema] - scoped = [GraphQLService, RootQuery, CollectionGraphType] + singleton = [Schema, RootQuery] + scoped = [GraphQLService] @staticmethod def configure(services: ServiceProvider) -> None: diff --git a/src/cpl-graphql/cpl/graphql/schema/argument.py b/src/cpl-graphql/cpl/graphql/schema/argument.py index 2f3b938c..cbf8b32f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/argument.py +++ b/src/cpl-graphql/cpl/graphql/schema/argument.py @@ -1,9 +1,21 @@ +from typing import Any + + class Argument: - def __init__(self, t: type, name: str, description: str = None, default_value=None): + + def __init__( + self, + t: type, + name: str, + description: str = None, + default_value: Any = None, + optional: bool = None, + ): self._type = t self._name = name self._description = description self._default_value = default_value + self._optional = optional @property def type(self) -> type: @@ -18,5 +30,9 @@ class Argument: return self._description @property - def default_value(self): + def default_value(self) -> Any | None: return self._default_value + + @property + def optional(self) -> bool | None: + return self._optional diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index f14269fc..68b8aa69 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -1,18 +1,53 @@ -from typing import Generic, Type +from typing import Type, Dict, List + +import strawberry from cpl.core.typing import T -from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol -class Collection(Generic[T]): +class CollectionGraphTypeFactory: + _cache: Dict[Type, Type] = {} + + @classmethod + def get(cls, node_type: Type[StrawberryProtocol]) -> Type: + if node_type in cls._cache: + return cls._cache[node_type] + + gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type + + gql_type = strawberry.type( + type( + f"{node_type.__name__}Collection", + (), + { + "__annotations__": { + "nodes": List[gql_node], + "total_count": int, + "count": int, + } + }, + ) + ) + + cls._cache[node_type] = gql_type + return gql_type + + +class Collection: def __init__(self, nodes: list[T], total_count: int, count: int): - self.nodes = nodes - self.totalCount = total_count - self.count = count + self._nodes = nodes + self._total_count = total_count + self._count = count -class CollectionGraphType(GraphType[T]): - def __init__(self, t: Type[GraphType[T]]): - GraphType.__init__(self) - self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount) - self.string_field("count", resolver=lambda obj, *_: obj.count) - self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes) + @property + def nodes(self) -> list[T]: + return self._nodes + + @property + def total_count(self) -> int: + return self._total_count + + @property + def count(self) -> int: + return self._count diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index e6358e83..2231e11c 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -6,11 +6,24 @@ from cpl.graphql.typing import TQuery, Resolver class Field: - def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None): + def __init__( + self, + name: str, + gql_type: type = None, + resolver: Resolver = None, + optional=None, + default=None, + subquery: TQuery = None, + parent_type=None, + ): self._name = name self._gql_type = gql_type self._resolver = resolver + self._optional = optional or True + self._default = default + self._subquery = subquery + self._parent_type = parent_type self._args: dict[str, Argument] = {} @@ -26,6 +39,14 @@ class Field: def resolver(self) -> callable: return self._resolver + @property + def optional(self) -> bool | None: + return self._optional + + @property + def default(self): + return self._default + @property def args(self) -> dict: return self._args @@ -34,10 +55,18 @@ class Field: def subquery(self) -> TQuery | None: return self._subquery - def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self: + @property + def parent_type(self): + return self._parent_type + + @property + def arguments(self) -> dict[str, Argument]: + return self._args + + def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") - self._args[name] = Argument(arg_type, name, description, default_value) + self._args[name] = Argument(arg_type, name, description, default_value, optional) return self def with_arguments(self, args: list[Argument]) -> Self: @@ -45,5 +74,5 @@ class Field: if not isinstance(arg, Argument): raise ValueError(f"Expected Argument instance, got {type(arg)}") - self.with_argument(arg.type, arg.name, arg.description, arg.default_value) + self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) return self diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index 26339bbc..2f76c4b4 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -3,7 +3,5 @@ from cpl.graphql.schema.input import Input class Filter(Input[T]): - def __init__( - self, - ): + def __init__(self): Input.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/graph_type.py b/src/cpl-graphql/cpl/graphql/schema/graph_type.py index 8fff69cf..e829b82d 100644 --- a/src/cpl-graphql/cpl/graphql/schema/graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/graph_type.py @@ -4,7 +4,7 @@ from cpl.core.typing import T from cpl.graphql.schema.query import Query -class GraphType(Generic[T], Query): +class GraphType(Query, Generic[T]): def __init__(self): Query.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 8f66c69c..4c9afc86 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,26 +1,34 @@ -from datetime import datetime -from enum import Enum -from typing import Type, Generic +from typing import Generic, Dict, Type, Any, Optional -import graphene +import strawberry from cpl.core.typing import T +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field -class Input(Generic[T], graphene.InputObjectType): - def __init__( - self, - ): - graphene.InputObjectType.__init__(self) - self._fields: dict[str, Field] = {} +class Input(StrawberryProtocol, Generic[T]): + def __init__(self): + self._fields: Dict[str, Field] = {} def get_fields(self) -> dict[str, Field]: return self._fields - def field( - self, - field: str, - t: Type["Input"] | Type[int | str | bool | datetime | list | Enum], - ): - self._fields[field] = Field(field, t) + def field(self, name: str, typ: type, optional: bool = True): + self._fields[name] = Field(name, typ, optional=optional) + + def to_strawberry(self) -> Type: + annotations = {} + namespace = {} + + for name, f in self._fields.items(): + ann = f.type if not f.optional else Optional[f.type] + annotations[name] = ann + + if f.optional: + namespace[name] = None + elif f.default is not None: + namespace[name] = f.default + + namespace["__annotations__"] = annotations + return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace)) diff --git a/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py deleted file mode 100644 index 5cc46a0a..00000000 --- a/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py +++ /dev/null @@ -1,9 +0,0 @@ -from cpl.core.typing import T -from cpl.graphql.schema.graph_type import GraphType -from cpl.graphql.schema.query import Query - - -class ObjectGraphType(GraphType[T], Query): - - def __init__(self): - Query.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 0e14d6b9..a453734a 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,21 +1,27 @@ -from typing import Callable, Type +import inspect +from typing import Callable, Type, Any, Optional -from graphene import ObjectType +import strawberry +from strawberry.exceptions import StrawberryException -from cpl.graphql.schema.argument import Argument +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field -from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.typing import Resolver -class Query(ObjectType): +class Query(StrawberryProtocol): - def __init__(self): - from cpl.graphql.schema.field import Field + @inject + def __init__(self, provider: ServiceProvider): + self._provider = provider - ObjectType.__init__(self) + from cpl.graphql.service.schema import Schema + + self._schema = provider.get_service(Schema) self._fields: dict[str, Field] = {} def get_fields(self) -> dict[str, Field]: @@ -25,69 +31,137 @@ class Query(ObjectType): self, name: str, t: type, - resolver: Callable | None = None, - ) -> "Field": + resolver: Resolver = None, + ) -> Field: from cpl.graphql.schema.field import Field self._fields[name] = Field(name, t, resolver) return self._fields[name] - def with_query(self, name: str, subquery: Type["Query"]): - from cpl.graphql.schema.field import Field - - f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery) - self._fields[name] = f - return self._fields[name] - - def string_field(self, name: str, resolver: Resolver = None) -> "Field": + def string_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, str, resolver) - def int_field(self, name: str, resolver: Resolver = None) -> "Field": + def int_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, int, resolver) - def float_field(self, name: str, resolver: Resolver = None) -> "Field": + def float_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, float, resolver) - def bool_field(self, name: str, resolver: Resolver = None) -> "Field": + def bool_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, bool, resolver) - def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field": + def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: return self.field(name, list[t], resolver) + def with_query(self, name: str, subquery_cls: Type["Query"]): + sub = self._provider.get_service(subquery_cls) + if not sub: + raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider") + + self.field(name, sub.to_strawberry(), lambda: sub) + return self + def collection_field( - self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None - ) -> "Field": - from cpl.graphql.schema.collection import Collection, CollectionGraphType + self, + t: type, + name: str, + filter_type: Type[StrawberryProtocol], + sort_type: Type[StrawberryProtocol], + resolver: Callable, + ) -> Field: + # self._schema.with_type(filter_type) + # self._schema.with_type(sort_type) - def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int): + def _resolve_collection(filter=None, sort=None, skip=0, take=10): items = resolver() + if filter: + for field, value in filter.__dict__.items(): + if value is None: + continue + items = [i for i in items if getattr(i, field) == value] - for field in filter or []: - if filter[field] is None: - continue - - items = [item for item in items if getattr(item, field) == filter[field]] - - for field in sort or []: - if sort[field] is None: - continue - - reverse = sort[field] == SortOrder.DESC - items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse) - + if sort: + for field, direction in sort.__dict__.items(): + reverse = direction == SortOrder.DESC + items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse) total_count = len(items) paged = items[skip : skip + take] return Collection(nodes=paged, total_count=total_count, count=len(paged)) - # base = getattr(t, "__gqlname__", t.__class__.__name__) - wrapper = CollectionGraphType(t) - # wrapper.set_graphql_name(f"{base}Collection") - f = self.field(name, wrapper, resolver=_resolve_collection) - return f.with_arguments( - [ - Argument(filter_type, "filter"), - Argument(sort_type, "sort"), - Argument(int, "skip", default_value=0), - Argument(int, "take", default_value=10), - ] - ) + filter = self._provider.get_service(filter_type) + if not filter: + raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider") + + sort = self._provider.get_service(sort_type) + if not sort: + raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + + f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection) + f.with_argument(filter.to_strawberry(), "filter") + f.with_argument(sort.to_strawberry(), "sort") + f.with_argument(int, "skip", default_value=0) + f.with_argument(int, "take", default_value=10) + return f + + @staticmethod + def _build_resolver(f: "Field"): + params: list[inspect.Parameter] = [] + for arg in f.arguments.values(): + ann = Optional[arg.type] if arg.optional else arg.type + + if arg.default_value is None: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + ) + else: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + default=arg.default_value, + ) + + params.append(param) + + sig = inspect.Signature(parameters=params, return_annotation=f.type) + + def _resolver(*args, **kwargs): + return f.resolver(*args, **kwargs) if f.resolver else None + + _resolver.__signature__ = sig + return _resolver + + def _field_to_strawberry(self, f: Field) -> Any: + try: + if f.resolver: + ann = getattr(f.resolver, "__annotations__", {}) + if "return" not in ann or ann["return"] is None: + ann = dict(ann) + ann["return"] = f.type + f.resolver.__annotations__ = ann + + if f.arguments: + resolver = self._build_resolver(f) + return strawberry.field(resolver=resolver) + + if not f.resolver: + return strawberry.field(resolver=lambda *_, **__: None) + + return strawberry.field(resolver=f.resolver) + except StrawberryException as e: + raise Exception( + f"Error converting field '{f.name}' to strawberry field: {e}" + ) from e + + def to_strawberry(self) -> Type: + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + annotations[name] = f.type + namespace[name] = self._field_to_strawberry(f) + + namespace["__annotations__"] = annotations + return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace)) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 9912c739..23627ee4 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,43 +1,54 @@ -import graphene +from typing import Type, Self + +import strawberry from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider -from cpl.graphql.schema.collection import CollectionGraphType -from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.root_query import RootQuery -from cpl.graphql.service.type_converter import TypeConverter class Schema: - def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider): + def __init__(self, logger: APILogger, provider: ServiceProvider): self._logger = logger self._provider = provider - self._converter = converter - self._types = set(GraphType.__subclasses__()) - self._types.remove(CollectionGraphType) + self._types: dict[str, Type[StrawberryProtocol]] = {} - self._query = query self._schema = None @property - def schema(self) -> graphene.Schema | None: + def schema(self) -> strawberry.Schema | None: return self._schema @property def query(self) -> RootQuery: - return self._query + return self._provider.get_service(RootQuery) - def with_type(self, t: type[GraphType]): - self._types.add(t) + def with_type(self, t: Type[StrawberryProtocol]) -> Self: + self._types[t.__name__] = t return self - def build(self) -> graphene.Schema: - self._schema = graphene.Schema( - query=self._converter.to_graphene(self._query), + def _get_types(self): + types: list[Type] = [] + for t in self._types.values(): + t_obj = self._provider.get_service(t) + if not t_obj: + raise ValueError(f"Type '{t.__name__}' not registered in service provider") + types.append(t_obj.to_strawberry()) + + return types + + def build(self) -> strawberry.Schema: + query = self._provider.get_service(RootQuery) + if not query: + raise ValueError("RootQuery not registered in service provider") + + self._schema = strawberry.Schema( + query=query.to_strawberry(), mutation=None, subscription=None, - # types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None, + types=self._get_types(), ) return self._schema diff --git a/src/cpl-graphql/cpl/graphql/service/service.py b/src/cpl-graphql/cpl/graphql/service/service.py index 54c4f388..f039ccbd 100644 --- a/src/cpl-graphql/cpl/graphql/service/service.py +++ b/src/cpl-graphql/cpl/graphql/service/service.py @@ -16,7 +16,7 @@ class GraphQLService: variables: Optional[Dict[str, Any]], request: TRequest, ) -> Dict[str, Any]: - result = await self._schema.execute_async( + result = await self._schema.execute( query, variable_values=variables, context_value={"request": request}, diff --git a/src/cpl-graphql/cpl/graphql/service/type_converter.py b/src/cpl-graphql/cpl/graphql/service/type_converter.py deleted file mode 100644 index bf483b42..00000000 --- a/src/cpl-graphql/cpl/graphql/service/type_converter.py +++ /dev/null @@ -1,89 +0,0 @@ -import typing -from enum import Enum -from inspect import isclass - -import graphene -from typing import Any, get_origin, get_args - -from cpl.dependency import ServiceProvider -from cpl.graphql.schema.argument import Argument -from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.graph_type import GraphType -from cpl.graphql.schema.object_graph_type import ObjectGraphType -from cpl.graphql.schema.sort.sort import Sort -from cpl.graphql.typing import Resolver -from cpl.graphql.utils.name_pipe import NamePipe - - -class TypeConverter: - __scalar_map: dict[Any, type[graphene.Scalar]] = { - str: graphene.String, - int: graphene.Int, - float: graphene.Float, - bool: graphene.Boolean, - } - - def __init__(self, provider: ServiceProvider): - self._provider = provider - - def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field: - arguments = {} - if args is not None: - arguments = { - arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value) - for arg in args.values() - } - - return graphene.Field(t, args=arguments, resolver=resolver) - - def to_graphene(self, t: Any, name: str | None = None) -> Any: - try: - origin = get_origin(t) - args = get_args(t) - - if t in self.__scalar_map: - return self.__scalar_map[t] - - if origin in (list, typing.List): - if not args: - raise ValueError("List must specify element type, e.g. list[str]") - inner = self.to_graphene(args[0]) - return graphene.List(inner) - - if t is list or t is typing.List: - raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]") - - if isclass(t) and issubclass(t, Enum): - return graphene.Enum.from_enum(t) - - from cpl.graphql.schema.query import Query - if isinstance(t, type) and issubclass(t, (Query)): - query = self._provider.get_service(t) - if query is None: - raise ValueError(f"Could not resolve query of type {t}") - - t = query - - if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)): - t = t() - - if isinstance(t, (Query, Filter, Sort)): - attrs = {} - for field in t.get_fields().values(): - if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None: - subquery = self._provider.get_service(field.subquery) - sub = self.to_graphene(subquery, name=field.name.capitalize()) - attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver) - continue - - attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver) - - class_name = NamePipe.to_str(name or t.__class__) - if isinstance(t, (Filter, Sort)): - return type(class_name, (graphene.InputObjectType,), attrs) - - return type(class_name, (graphene.ObjectType,), attrs) - - raise ValueError(f"Unsupported field type: {t}") - except Exception as e: - raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e \ No newline at end of file diff --git a/src/cpl-graphql/requirements.txt b/src/cpl-graphql/requirements.txt index abe92c36..d74de843 100644 --- a/src/cpl-graphql/requirements.txt +++ b/src/cpl-graphql/requirements.txt @@ -1,2 +1,2 @@ cpl-api -graphene==3.4.3 \ No newline at end of file +strawberry-graphql==0.282.0 \ No newline at end of file