From cdb5e4ff894e9f08c85d8dc017c858c91dae6d81 Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 4 Oct 2025 06:57:14 +0200 Subject: [PATCH] [WIP] Subscriptions --- example/api/src/main.py | 14 ++- example/api/src/model/post_query.py | 29 ++++- example/api/src/queries/hello.py | 5 +- src/cpl-api/cpl/api/middleware/request.py | 4 +- .../cpl/database/mysql/mysql_pool.py | 9 +- .../cpl/dependency/event_bus.py | 10 ++ .../cpl/graphql/abc/strawberry_protocol.py | 4 +- .../cpl/graphql/application/graphql_app.py | 4 +- .../graphql/auth/api_key/api_key_mutation.py | 17 ++- .../cpl/graphql/event_bus/__init__.py | 0 .../cpl/graphql/event_bus/memory.py | 27 +++++ src/cpl-graphql/cpl/graphql/graphql_module.py | 3 +- .../cpl/graphql/schema/collection.py | 9 +- .../cpl/graphql/schema/db_model_graph_type.py | 1 + .../graphql/schema/filter/db_model_filter.py | 1 + .../cpl/graphql/schema/mutation.py | 28 ++--- .../cpl/graphql/schema/root_subscription.py | 6 + .../cpl/graphql/schema/subscription.py | 109 ++++++++++++++++++ .../cpl/graphql/schema/subscription_field.py | 11 ++ src/cpl-graphql/cpl/graphql/service/schema.py | 10 +- src/cpl-graphql/cpl/graphql/typing.py | 5 +- 21 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 src/cpl-dependency/cpl/dependency/event_bus.py create mode 100644 src/cpl-graphql/cpl/graphql/event_bus/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/event_bus/memory.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/root_subscription.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/subscription.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/subscription_field.py diff --git a/example/api/src/main.py b/example/api/src/main.py index ac906f9b..7dc53652 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,8 +1,10 @@ from starlette.responses import JSONResponse -from api.src.queries.cities import CityGraphType, CityFilter, CitySort -from api.src.queries.hello import UserGraphType#, UserFilter, UserSort, UserGraphType -from api.src.queries.user import UserFilter, UserSort +from cpl.dependency.event_bus import EventBusABC +from cpl.graphql.event_bus.memory import InMemoryEventBus +from queries.cities import CityGraphType, CityFilter, CitySort +from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType +from queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder from cpl.auth.schema import User, Role @@ -17,7 +19,7 @@ from cpl.graphql.graphql_module import GraphQLModule from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao -from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation +from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription from permissions import PostPermissions from queries.hello import HelloQuery from scoped_service import ScopedService @@ -41,6 +43,7 @@ def main(): .add_module(GraphQLModule) .add_module(GraphQLAuthModule) .add_scoped(ScopedService) + .add_singleton(EventBusABC, InMemoryEventBus) .add_cache(User) .add_cache(Role) .add_transient(CityGraphType) @@ -66,6 +69,7 @@ def main(): .add_transient(PostFilter) .add_transient(PostSort) .add_transient(PostMutation) + .add_transient(PostSubscription) ) app = builder.build() @@ -96,6 +100,8 @@ def main(): schema.mutation.with_mutation("post", PostMutation).with_public() + schema.subscription.with_subscription("post", PostSubscription) + app.with_auth_root_queries(True) app.with_auth_root_mutations(True) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index d12f308c..eab7525f 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,3 +1,4 @@ +from cpl.dependency.event_bus import EventBusABC from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.db_model_graph_type import DbModelGraphType from cpl.graphql.schema.filter.db_model_filter import DbModelFilter @@ -5,6 +6,7 @@ from cpl.graphql.schema.input import Input from cpl.graphql.schema.mutation import Mutation from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder +from cpl.graphql.schema.subscription import Subscription from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter from model.post import Post @@ -19,6 +21,7 @@ class PostFilter(DbModelFilter[Post]): self.string_field("title") self.string_field("content") + class PostSort(Sort[Post]): def __init__(self): Sort.__init__(self) @@ -26,6 +29,7 @@ class PostSort(Sort[Post]): self.field("title", SortOrder) self.field("content", SortOrder) + class PostGraphType(DbModelGraphType[Post]): def __init__(self, authors: AuthorDao): @@ -42,7 +46,7 @@ class PostGraphType(DbModelGraphType[Post]): def r_name(ctx: QueryContext): return ctx.user.username == "admin" - self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name])) + self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name])) self.string_field( "title", resolver=lambda root: root.title, @@ -64,6 +68,7 @@ class PostCreateInput(Input[Post]): self.string_field("content").with_required() self.int_field("author_id").with_required() + class PostUpdateInput(Input[Post]): title: str content: str @@ -75,13 +80,31 @@ class PostUpdateInput(Input[Post]): self.string_field("title").with_required(False) self.string_field("content").with_required(False) + +class PostSubscription(Subscription): + def __init__(self, bus: EventBusABC): + Subscription.__init__(self) + self._bus = bus + + async def post_changed(): + async for event in await self._bus.subscribe("postChange"): + print("Event:", event, type(event)) + yield event + + def selector(event: Post, info) -> bool: + return True + + self.subscription_field("postChange", PostGraphType, post_changed, selector) + + class PostMutation(Mutation): - def __init__(self, posts: PostDao, authors: AuthorDao): + def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC): Mutation.__init__(self) self._posts = posts self._authors = authors + self._bus = bus self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument( "input", @@ -112,6 +135,7 @@ class PostMutation(Mutation): post.content = input.content if input.content is not None else post.content await self._posts.update(post) + await self._bus.publish("postChange", post) return True async def delete_post(self, id: int) -> bool: @@ -127,4 +151,3 @@ class PostMutation(Mutation): return False await self._posts.restore(post) return True - diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index c53ce008..19a1f774 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -1,5 +1,5 @@ -from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City -from api.src.queries.user import User, UserFilter, UserSort, UserGraphType +from queries.cities import CityFilter, CitySort, CityGraphType, City +from queries.user import User, UserFilter, UserSort, UserGraphType from cpl.api.middleware.request import get_request from cpl.auth.schema import UserDao, User from cpl.graphql.schema.filter.filter import Filter @@ -38,6 +38,7 @@ cities = [City(i, f"City {i}") for i in range(1, 101)] # resolver=lambda root: root.username, # ) + class HelloQuery(Query): def __init__(self): Query.__init__(self) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 05a291e3..4f3ae5a4 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -21,9 +21,7 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): - def __init__( - self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao - ): + def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._provider = provider diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index 474bf6ce..b482229c 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -119,6 +119,13 @@ class MySQLPool: try: async with await con.cursor(dictionary=True) as cursor: res = await self._exec_sql(cursor, query, args, multi) - return list(res) + decoded_res = [] + for row in res: + decoded_row = { + k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items() + } + decoded_res.append(decoded_row) + + return decoded_res finally: await con.close() diff --git a/src/cpl-dependency/cpl/dependency/event_bus.py b/src/cpl-dependency/cpl/dependency/event_bus.py new file mode 100644 index 00000000..efd372aa --- /dev/null +++ b/src/cpl-dependency/cpl/dependency/event_bus.py @@ -0,0 +1,10 @@ +from abc import abstractmethod, ABC +from typing import Any, AsyncGenerator + + +class EventBusABC(ABC): + @abstractmethod + async def publish(self, channel: str, event: Any) -> None: ... + + @abstractmethod + async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ... diff --git a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py index 1c0b6592..ad8f18b8 100644 --- a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py +++ b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py @@ -1,9 +1,11 @@ from typing import Protocol, Type, runtime_checkable from cpl.graphql.schema.field import Field +from cpl.graphql.schema.subscription_field import SubscriptionField @runtime_checkable class StrawberryProtocol(Protocol): def to_strawberry(self) -> Type: ... - def get_fields(self) -> dict[str, Field]: ... + + def get_fields(self) -> dict[str, Field | SubscriptionField]: ... diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py index 8c15dec8..5a67de13 100644 --- a/src/cpl-graphql/cpl/graphql/application/graphql_app.py +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -114,4 +114,6 @@ class GraphQLApp(WebApp): if self._with_graphiql: self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql") if self._with_playground: - self._logger.warning(f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground") + self._logger.warning( + f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground" + ) diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py index c431eee8..67444e3b 100644 --- a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py @@ -8,13 +8,12 @@ from cpl.graphql.schema.mutation import Mutation class ApiKeyMutation(Mutation): def __init__( - self, - - logger: APILogger, - api_key_dao: ApiKeyDao, - api_key_permission_dao: ApiKeyPermissionDao, - permission_dao: ApiKeyPermissionDao, - keycloak_admin: KeycloakAdmin, + self, + logger: APILogger, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + permission_dao: ApiKeyPermissionDao, + keycloak_admin: KeycloakAdmin, ): Mutation.__init__(self) self._logger = logger @@ -61,9 +60,7 @@ class ApiKeyMutation(Mutation): api_key = ApiKey.new(obj.identifier) await self._api_key_dao.create(api_key) api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}]) - await self._api_key_permission_dao.create_many( - [ApiKeyPermission(0, api_key.id, x) for x in obj.permissions] - ) + await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions]) return api_key async def resolve_update(self, input: ApiKeyUpdateInput): diff --git a/src/cpl-graphql/cpl/graphql/event_bus/__init__.py b/src/cpl-graphql/cpl/graphql/event_bus/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/event_bus/memory.py b/src/cpl-graphql/cpl/graphql/event_bus/memory.py new file mode 100644 index 00000000..4d74c1af --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/event_bus/memory.py @@ -0,0 +1,27 @@ +import asyncio +from typing import Any, AsyncGenerator + +from cpl.dependency.event_bus import EventBusABC + + +class InMemoryEventBus(EventBusABC): + def __init__(self): + self._subscribers: dict[str, list[asyncio.Queue]] = {} + + async def publish(self, channel: str, event: Any) -> None: + queues = self._subscribers.get(channel, []) + for q in queues.copy(): + await q.put(event) + + async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: + q = asyncio.Queue() + if channel not in self._subscribers: + self._subscribers[channel] = [] + self._subscribers[channel].append(q) + + try: + while True: + item = await q.get() + yield item + finally: + self._subscribers[channel].remove(q) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index b749d16e..3672e119 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -8,13 +8,14 @@ from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.schema.root_subscription import RootSubscription from cpl.graphql.service.graphql import GraphQLService from cpl.graphql.service.schema import Schema class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [Schema, RootQuery, RootMutation] + singleton = [Schema, RootQuery, RootMutation, RootSubscription] scoped = [GraphQLService] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 9d600ab9..650fc71e 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -9,6 +9,7 @@ from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.utils.type_collector import TypeCollector + class CollectionGraphTypeFactory: @classmethod def get(cls, node_type: Type[StrawberryProtocol]) -> Type: @@ -23,11 +24,7 @@ class CollectionGraphTypeFactory: gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type - gql_cls = type( - type_name, - (), - {} - ) + gql_cls = type(type_name, (), {}) TypeCollector.set(type_name, gql_cls) @@ -45,8 +42,6 @@ class CollectionGraphTypeFactory: return gql_type - - class Collection: def __init__(self, nodes: list[T], total_count: int, count: int): self._nodes = nodes diff --git a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py index 2b9a39bb..ed4153a2 100644 --- a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py @@ -26,6 +26,7 @@ class DbModelGraphType(GraphType[T], Generic[T]): if Configuration.get("GraphQLAuthModuleEnabled", False): from cpl.graphql.auth.user.user_graph_type import UserGraphType + self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public) self.string_field("created", lambda root: root.created).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py index a7a22cb7..4a91544c 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -16,6 +16,7 @@ class DbModelFilter(Filter[T], Generic[T]): self.field("deleted", BoolFilter).with_public(public) if Configuration.get("GraphQLAuthModuleEnabled", False): from cpl.graphql.auth.user.user_filter import UserFilter + self.field("editor", lambda: UserFilter).with_public(public) self.field("created", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/mutation.py b/src/cpl-graphql/cpl/graphql/schema/mutation.py index 82a707e9..d336f3b1 100644 --- a/src/cpl-graphql/cpl/graphql/schema/mutation.py +++ b/src/cpl-graphql/cpl/graphql/schema/mutation.py @@ -28,14 +28,14 @@ class Mutation(QueryABC): @staticmethod async def _resolve_assignments( - foreign_objs: list[int], - resolved_obj: T, - reference_key_own: Union[str, property], - reference_key_foreign: Union[str, property], - source_dao: DataAccessObjectABC[T], - join_dao: DataAccessObjectABC[T], - join_type: Type[DbJoinModelABC], - foreign_dao: DataAccessObjectABC[T], + foreign_objs: list[int], + resolved_obj: T, + reference_key_own: Union[str, property], + reference_key_foreign: Union[str, property], + source_dao: DataAccessObjectABC[T], + join_dao: DataAccessObjectABC[T], + join_type: Type[DbJoinModelABC], + foreign_dao: DataAccessObjectABC[T], ): if foreign_objs is None: return @@ -44,9 +44,7 @@ class Mutation(QueryABC): if isinstance(reference_key_foreign, property): reference_key_foreign_attr = reference_key_foreign.fget.__name__ - foreign_list = await join_dao.find_by( - [{reference_key_own: resolved_obj.id}, {"deleted": False}] - ) + foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}]) to_delete = ( foreign_list @@ -61,9 +59,7 @@ class Mutation(QueryABC): foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list] deleted_foreign_ids = [ getattr(x, reference_key_foreign_attr) - for x in await join_dao.find_by( - [{reference_key_own: resolved_obj.id}, {"deleted": True}] - ) + for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}]) ] to_create = [ @@ -94,6 +90,4 @@ class Mutation(QueryABC): foreign_changes = [*to_delete, *to_create, *to_restore] if len(foreign_changes) > 0: await source_dao.touch(resolved_obj) - await foreign_dao.touch_many_by_id( - [getattr(x, reference_key_foreign_attr) for x in foreign_changes] - ) + await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes]) diff --git a/src/cpl-graphql/cpl/graphql/schema/root_subscription.py b/src/cpl-graphql/cpl/graphql/schema/root_subscription.py new file mode 100644 index 00000000..fab2bc8f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_subscription.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.subscription import Subscription + + +class RootSubscription(Subscription): + def __init__(self): + Subscription.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/subscription.py b/src/cpl-graphql/cpl/graphql/schema/subscription.py new file mode 100644 index 00000000..5d73a980 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/subscription.py @@ -0,0 +1,109 @@ +import inspect +import types +from abc import ABC +from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional + +import strawberry +from strawberry.exceptions import StrawberryException + +from cpl.dependency import ServiceProvider, get_provider, inject +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.subscription_field import SubscriptionField +from cpl.graphql.typing import Selector +from cpl.graphql.utils.type_collector import TypeCollector + + +class Subscription(ABC, StrawberryProtocol): + + @inject + def __init__(self, provider: ServiceProvider): + ABC.__init__(self) + self._provider = provider + self._fields: Dict[str, SubscriptionField] = {} + + @property + def fields(self) -> dict[str, SubscriptionField]: + return self._fields + + @property + def fields_count(self) -> int: + return len(self._fields) + + def get_fields(self) -> dict[str, SubscriptionField]: + return self._fields + + def subscription_field( + self, + name: str, + type_: Type, + resolver: Callable[..., AsyncGenerator], + selector: Selector = None, + ) -> SubscriptionField: + f = SubscriptionField(name, type_, resolver, selector) + self._fields[name] = f + return f + + def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField: + sub = self._provider.get_service(sub_cls) + if not sub: + raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider") + + async def _resolver(root, info): + return sub + + self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver) + return self._fields[name] + + @staticmethod + def _type_to_strawberry(t: Type) -> Type: + _t = get_provider().get_service(t) + if isinstance(_t, StrawberryProtocol): + return _t.to_strawberry() + return t + + @staticmethod + def _build_resolver(f: SubscriptionField) -> Callable: + async def _resolver(root, info): + async for event in f.resolver(root, info): + if not f.selector or f.selector(event, info): + yield event + + return _resolver + + def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + + gql_cls = type(cls.__name__, (), {}) + TypeCollector.set(cls, gql_cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + t = f.type + if isinstance(t, types.GenericAlias): + t = t.__args__[0] + elif isinstance(t, type) and issubclass(t, StrawberryProtocol): + t = self._type_to_strawberry(t) + + annotations[name] = Optional[t] + + try: + namespace[name] = strawberry.subscription(resolver=self._build_resolver(f)) + + except StrawberryException as e: + raise Exception(f"Error converting subscription field '{f.name}': {e}") from e + + gql_cls.__annotations__ = annotations + for k, v in namespace.items(): + setattr(gql_cls, k, v) + + try: + gql_type = strawberry.type(gql_cls) + except Exception as e: + raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e + + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/subscription_field.py b/src/cpl-graphql/cpl/graphql/schema/subscription_field.py new file mode 100644 index 00000000..3e60cda4 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/subscription_field.py @@ -0,0 +1,11 @@ +from typing import Type, Callable + +from cpl.graphql.typing import Selector + + +class SubscriptionField: + def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None): + self.name = name + self.type = type_ + self.resolver = resolver + self.selector = selector diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 9141f455..2a5a2dd0 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -8,6 +8,7 @@ from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.schema.root_subscription import RootSubscription class Schema: @@ -38,6 +39,13 @@ class Schema: raise ValueError("RootMutation not registered in service provider") return mutation + @property + def subscription(self) -> RootSubscription: + subscription = self._provider.get_service(RootSubscription) + if not subscription: + raise ValueError("RootSubscription not registered in service provider") + return subscription + def with_type(self, t: Type[StrawberryProtocol]) -> Self: self._types[t.__name__] = t return self @@ -61,7 +69,7 @@ class Schema: self._schema = strawberry.Schema( query=query.to_strawberry() if query.fields_count > 0 else None, mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None, - subscription=None, + subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None, types=self._get_types(), ) return self._schema diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py index 3dd33106..bb8cda8e 100644 --- a/src/cpl-graphql/cpl/graphql/typing.py +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -1,11 +1,14 @@ from enum import Enum -from typing import Type, Callable, List, Tuple, Awaitable +from typing import Type, Callable, List, Tuple, Awaitable, Any + +import strawberry from cpl.auth.permission import Permissions from cpl.graphql.query_context import QueryContext TQuery = Type["Query"] Resolver = Callable +Selector = Callable[[Any, strawberry.types.Info], bool] ScalarType = str | int | float | bool | object AttributeName = str | property TRequireAnyPermissions = List[Enum | Permissions] | None