[WIP] collection #181

This commit is contained in:
2025-09-27 04:08:32 +02:00
parent 7673c3d10e
commit 7772a0a51c
21 changed files with 375 additions and 85 deletions

View File

@@ -1,9 +1,9 @@
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType
from api.src.queries.hello import UserGraphType
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.api.application.web_app import WebApp
from cpl.application.application_builder import ApplicationBuilder from cpl.application.application_builder import ApplicationBuilder
from cpl.graphql.application.graphql_app import GraphQLApp
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
@@ -11,8 +11,8 @@ from cpl.core.console import Console
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule from cpl.database.mysql.mysql_module import MySQLModule
from cpl.graphql.application.graphql_app import GraphQLApp
from cpl.graphql.graphql_module import GraphQLModule from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.schema.root_query import RootQuery
from queries.hello import HelloQuery from queries.hello import HelloQuery
from scoped_service import ScopedService from scoped_service import ScopedService
from service import PingService from service import PingService
@@ -37,6 +37,8 @@ def main():
builder.services.add_cache(AuthUser) builder.services.add_cache(AuthUser)
builder.services.add_cache(Role) builder.services.add_cache(Role)
builder.services.add_transient(CityGraphType)
builder.services.add_transient(UserGraphType)
builder.services.add_transient(HelloQuery) builder.services.add_transient(HelloQuery)
app = builder.build() app = builder.build()

View File

@@ -0,0 +1,39 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.object_graph_type import ObjectGraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
class City:
def __init__(self, id: int, name: str):
self.id = id
self.name = name
class CityFilter(Filter[City]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("name", str)
class CitySort(Sort[City]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("name", SortOrder)
class CityGraphType(ObjectGraphType):
def __init__(self):
ObjectGraphType.__init__(self)
self.string_field(
"id",
resolver=lambda user, *_: user.id,
)
self.string_field(
"name",
resolver=lambda user, *_: user.name,
)

View File

@@ -1,6 +1,10 @@
from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City
from api.src.queries.user import User, UserFilter, UserSort, UserGraphType
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.graphql.schema.query import Query from cpl.graphql.schema.query import Query
users = [User(i, f"User {i}") for i in range(1, 101)]
cities = [City(i, f"City {i}") for i in range(1, 101)]
class HelloQuery(Query): class HelloQuery(Query):
def __init__(self): def __init__(self):
@@ -9,3 +13,18 @@ class HelloQuery(Query):
"message", "message",
resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}", resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}",
).with_argument(str, "name", "Name to greet", "world") ).with_argument(str, "name", "Name to greet", "world")
self.collection_field(
UserGraphType,
"users",
UserFilter,
UserSort,
resolver=lambda *_: users,
)
self.collection_field(
CityGraphType,
"cities",
CityFilter,
CitySort,
resolver=lambda *_: cities,
)

View File

@@ -0,0 +1,39 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.object_graph_type import ObjectGraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
class User:
def __init__(self, id: int, name: str):
self.id = id
self.name = name
class UserFilter(Filter[User]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("name", str)
class UserSort(Sort[User]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("name", SortOrder)
class UserGraphType(ObjectGraphType):
def __init__(self):
ObjectGraphType.__init__(self)
self.string_field(
"id",
resolver=lambda user, *_: user.id,
)
self.string_field(
"name",
resolver=lambda user, *_: user.name,
)

View File

@@ -25,7 +25,7 @@ class ServiceProvider:
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if typing.get_origin(service_type) is None and ( if typing.get_origin(service_type) is None and (
descriptor.service_type == service_type descriptor.service_type.__name__ == service_type.__name__
or typing.get_origin(descriptor.base_type) is None or typing.get_origin(descriptor.base_type) is None
and issubclass(descriptor.base_type, service_type) and issubclass(descriptor.base_type, service_type)
): ):

View File

@@ -1,15 +1,17 @@
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.dependency.module.module import Module from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.collection import CollectionGraphType
from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema from cpl.graphql.service.schema import Schema
from cpl.graphql.service.service import GraphQLService from cpl.graphql.service.service import GraphQLService
from cpl.graphql.service.type_converter import TypeConverter
class GraphQLModule(Module): class GraphQLModule(Module):
dependencies = [ApiModule] dependencies = [ApiModule]
singleton = [Schema] singleton = [TypeConverter, Schema]
scoped = [GraphQLService, RootQuery] scoped = [GraphQLService, RootQuery, CollectionGraphType]
@staticmethod @staticmethod
def configure(services: ServiceProvider) -> None: def configure(services: ServiceProvider) -> None:

View File

@@ -0,0 +1,18 @@
from typing import Generic, Type
from cpl.core.typing import T
from cpl.graphql.schema.graph_type import GraphType
class Collection(Generic[T]):
def __init__(self, nodes: list[T], total_count: int, count: int):
self.nodes = nodes
self.totalCount = total_count
self.count = count
class CollectionGraphType(GraphType[T]):
def __init__(self, t: Type[GraphType[T]]):
GraphType.__init__(self)
self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount)
self.string_field("count", resolver=lambda obj, *_: obj.count)
self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes)

View File

@@ -1,12 +1,12 @@
from typing import Self from typing import Self
from cpl.graphql.schema.argument import Argument from cpl.graphql.schema.argument import Argument
from cpl.graphql.typing import TQuery from cpl.graphql.typing import TQuery, Resolver
class Field: class Field:
def __init__(self, name: str, gql_type: type, resolver: callable, subquery: TQuery | None = None): def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None):
self._name = name self._name = name
self._gql_type = gql_type self._gql_type = gql_type
self._resolver = resolver self._resolver = resolver
@@ -37,7 +37,7 @@ class Field:
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self: def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self:
if name in self._args: if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(name, arg_type, description, default_value) self._args[name] = Argument(arg_type, name, description, default_value)
return self return self
def with_arguments(self, args: list[Argument]) -> Self: def with_arguments(self, args: list[Argument]) -> Self:
@@ -45,5 +45,5 @@ class Field:
if not isinstance(arg, Argument): if not isinstance(arg, Argument):
raise ValueError(f"Expected Argument instance, got {type(arg)}") raise ValueError(f"Expected Argument instance, got {type(arg)}")
self.with_argument(arg.name, arg.type, arg.description, arg.default_value) self.with_argument(arg.type, arg.name, arg.description, arg.default_value)
return self return self

View File

@@ -0,0 +1,9 @@
from cpl.core.typing import T
from cpl.graphql.schema.input import Input
class Filter(Input[T]):
def __init__(
self,
):
Input.__init__(self)

View File

@@ -0,0 +1,10 @@
from typing import Generic
from cpl.core.typing import T
from cpl.graphql.schema.query import Query
class GraphType(Generic[T], Query):
def __init__(self):
Query.__init__(self)

View File

@@ -0,0 +1,26 @@
from datetime import datetime
from enum import Enum
from typing import Type, Generic
import graphene
from cpl.core.typing import T
from cpl.graphql.schema.field import Field
class Input(Generic[T], graphene.InputObjectType):
def __init__(
self,
):
graphene.InputObjectType.__init__(self)
self._fields: dict[str, Field] = {}
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
field: str,
t: Type["Input"] | Type[int | str | bool | datetime | list | Enum],
):
self._fields[field] = Field(field, t)

View File

@@ -0,0 +1,9 @@
from cpl.core.typing import T
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
class ObjectGraphType(GraphType[T], Query):
def __init__(self):
Query.__init__(self)

View File

@@ -2,9 +2,12 @@ from typing import Callable, Type
from graphene import ObjectType from graphene import ObjectType
from cpl.graphql.schema.argument import Argument
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.typing import Resolver from cpl.graphql.typing import Resolver
from cpl.graphql.utils.type_converter import TypeConverter
class Query(ObjectType): class Query(ObjectType):
@@ -32,7 +35,7 @@ class Query(ObjectType):
def with_query(self, name: str, subquery: Type["Query"]): def with_query(self, name: str, subquery: Type["Query"]):
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
f = Field(name=name, gql_type=object, resolver=lambda root, info, **kwargs: {}, subquery=subquery) f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery)
self._fields[name] = f self._fields[name] = f
return self._fields[name] return self._fields[name]
@@ -47,3 +50,44 @@ class Query(ObjectType):
def bool_field(self, name: str, resolver: Resolver = None) -> "Field": def bool_field(self, name: str, resolver: Resolver = None) -> "Field":
return self.field(name, bool, resolver) return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field":
return self.field(name, list[t], resolver)
def collection_field(
self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None
) -> "Field":
from cpl.graphql.schema.collection import Collection, CollectionGraphType
def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int):
items = resolver()
for field in filter or []:
if filter[field] is None:
continue
items = [item for item in items if getattr(item, field) == filter[field]]
for field in sort or []:
if sort[field] is None:
continue
reverse = sort[field] == SortOrder.DESC
items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse)
total_count = len(items)
paged = items[skip : skip + take]
return Collection(nodes=paged, total_count=total_count, count=len(paged))
# base = getattr(t, "__gqlname__", t.__class__.__name__)
wrapper = CollectionGraphType(t)
# wrapper.set_graphql_name(f"{base}Collection")
f = self.field(name, wrapper, resolver=_resolve_collection)
return f.with_arguments(
[
Argument(filter_type, "filter"),
Argument(sort_type, "sort"),
Argument(int, "skip", default_value=0),
Argument(int, "take", default_value=10),
]
)

View File

@@ -0,0 +1,9 @@
from cpl.core.typing import T
from cpl.graphql.schema.input import Input
class Sort(Input[T]):
def __init__(
self,
):
Input.__init__(self)

View File

@@ -0,0 +1,6 @@
from enum import Enum, auto
class SortOrder(Enum):
ASC = auto()
DESC = auto()

View File

@@ -1,21 +1,22 @@
from typing import Type
import graphene import graphene
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.argument import Argument from cpl.graphql.schema.collection import CollectionGraphType
from cpl.graphql.schema.query import Query from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.typing import Resolver from cpl.graphql.service.type_converter import TypeConverter
from cpl.graphql.utils.type_converter import TypeConverter
class Schema: class Schema:
def __init__(self, logger: APILogger, query: RootQuery, provider: ServiceProvider): def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider):
self._logger = logger self._logger = logger
self._provider = provider self._provider = provider
self._converter = converter
self._types = set(GraphType.__subclasses__())
self._types.remove(CollectionGraphType)
self._query = query self._query = query
self._schema = None self._schema = None
@@ -28,37 +29,15 @@ class Schema:
def query(self) -> RootQuery: def query(self) -> RootQuery:
return self._query return self._query
def with_type(self, t: type[GraphType]):
self._types.add(t)
return self
def build(self) -> graphene.Schema: def build(self) -> graphene.Schema:
self._schema = graphene.Schema( self._schema = graphene.Schema(
query=self.to_graphene(self._query), query=self._converter.to_graphene(self._query),
mutation=None, mutation=None,
subscription=None, subscription=None,
# types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None,
) )
return self._schema return self._schema
@staticmethod
def _field_to_graphene(t: Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field:
arguments = {}
if args is not None:
arguments = {
arg.name: graphene.Argument(TypeConverter.to_graphene(arg.type), description=arg.description, default_value=arg.default_value)
for arg in args.values()
}
return graphene.Field(t, args=arguments, resolver=resolver)
def to_graphene(self, query: Query, name: str | None = None):
assert query is not None, "Query cannot be None"
attrs = {}
for field in query.get_fields().values():
if field.type == object and field.subquery is not None:
subquery = self._provider.get_service(field.subquery)
sub = self.to_graphene(subquery, name=field.name.capitalize())
attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver)
continue
attrs[field.name] = self._field_to_graphene(TypeConverter.to_graphene(field.type), field.args, field.resolver)
class_name = name or query.__class__.__name__
return type(class_name, (graphene.ObjectType,), attrs)

View File

@@ -0,0 +1,89 @@
import typing
from enum import Enum
from inspect import isclass
import graphene
from typing import Any, get_origin, get_args
from cpl.dependency import ServiceProvider
from cpl.graphql.schema.argument import Argument
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.object_graph_type import ObjectGraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.typing import Resolver
from cpl.graphql.utils.name_pipe import NamePipe
class TypeConverter:
__scalar_map: dict[Any, type[graphene.Scalar]] = {
str: graphene.String,
int: graphene.Int,
float: graphene.Float,
bool: graphene.Boolean,
}
def __init__(self, provider: ServiceProvider):
self._provider = provider
def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field:
arguments = {}
if args is not None:
arguments = {
arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value)
for arg in args.values()
}
return graphene.Field(t, args=arguments, resolver=resolver)
def to_graphene(self, t: Any, name: str | None = None) -> Any:
try:
origin = get_origin(t)
args = get_args(t)
if t in self.__scalar_map:
return self.__scalar_map[t]
if origin in (list, typing.List):
if not args:
raise ValueError("List must specify element type, e.g. list[str]")
inner = self.to_graphene(args[0])
return graphene.List(inner)
if t is list or t is typing.List:
raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]")
if isclass(t) and issubclass(t, Enum):
return graphene.Enum.from_enum(t)
from cpl.graphql.schema.query import Query
if isinstance(t, type) and issubclass(t, (Query)):
query = self._provider.get_service(t)
if query is None:
raise ValueError(f"Could not resolve query of type {t}")
t = query
if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)):
t = t()
if isinstance(t, (Query, Filter, Sort)):
attrs = {}
for field in t.get_fields().values():
if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None:
subquery = self._provider.get_service(field.subquery)
sub = self.to_graphene(subquery, name=field.name.capitalize())
attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver)
continue
attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver)
class_name = NamePipe.to_str(name or t.__class__)
if isinstance(t, (Filter, Sort)):
return type(class_name, (graphene.InputObjectType,), attrs)
return type(class_name, (graphene.ObjectType,), attrs)
raise ValueError(f"Unsupported field type: {t}")
except Exception as e:
raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e

View File

@@ -0,0 +1,28 @@
from cpl.core.pipes import PipeABC
from cpl.core.typing import T
from cpl.graphql.schema.collection import CollectionGraphType
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.object_graph_type import ObjectGraphType
class NamePipe(PipeABC):
@staticmethod
def to_str(value: type, *args) -> str:
if isinstance(value, str):
return value
if not isinstance(value, type):
raise ValueError(f"Expected a type, got {type(value)}")
if issubclass(value, CollectionGraphType):
return f"{value.__name__.replace(GraphType.__name__, "")}"
if issubclass(value, (ObjectGraphType, GraphType)):
return value.__name__.replace(GraphType.__name__, "")
return value.__name__
@staticmethod
def from_str(value: str, *args) -> T:
pass

View File

@@ -1,38 +0,0 @@
from typing import Type
import graphene
from cpl.graphql.typing import ScalarType
class TypeConverter:
@staticmethod
def from_graphene(t: Type[graphene.Scalar]) -> ScalarType:
graphene_type_map: dict[Type[graphene.Scalar], ScalarType] = {
graphene.String: str,
graphene.Int: int,
graphene.Float: float,
graphene.Boolean: bool,
graphene.ObjectType: object,
}
if t not in graphene_type_map:
raise ValueError(f"Unsupported field type: {t}")
return graphene_type_map[t]
@staticmethod
def to_graphene(t: ScalarType) -> Type[graphene.Scalar]:
type_graphene_map: dict[ScalarType, Type[graphene.Scalar]] = {
str: graphene.String,
int: graphene.Int,
float: graphene.Float,
bool: graphene.Boolean,
object: graphene.ObjectType,
}
if t not in type_graphene_map:
raise ValueError(f"Unsupported field type: {t}")
return type_graphene_map[t]