[WIP] Subscriptions
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s

This commit is contained in:
2025-10-04 06:57:14 +02:00
parent e362b7fb61
commit cdb5e4ff89
21 changed files with 254 additions and 52 deletions

View File

@@ -1,8 +1,10 @@
from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
from api.src.queries.hello import UserGraphType#, UserFilter, UserSort, UserGraphType
from api.src.queries.user import UserFilter, UserSort
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.event_bus.memory import InMemoryEventBus
from queries.cities import CityGraphType, CityFilter, CitySort
from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType
from queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder
from cpl.auth.schema import User, Role
@@ -17,7 +19,7 @@ from cpl.graphql.graphql_module import GraphQLModule
from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription
from permissions import PostPermissions
from queries.hello import HelloQuery
from scoped_service import ScopedService
@@ -41,6 +43,7 @@ def main():
.add_module(GraphQLModule)
.add_module(GraphQLAuthModule)
.add_scoped(ScopedService)
.add_singleton(EventBusABC, InMemoryEventBus)
.add_cache(User)
.add_cache(Role)
.add_transient(CityGraphType)
@@ -66,6 +69,7 @@ def main():
.add_transient(PostFilter)
.add_transient(PostSort)
.add_transient(PostMutation)
.add_transient(PostSubscription)
)
app = builder.build()
@@ -96,6 +100,8 @@ def main():
schema.mutation.with_mutation("post", PostMutation).with_public()
schema.subscription.with_subscription("post", PostSubscription)
app.with_auth_root_queries(True)
app.with_auth_root_mutations(True)

View File

@@ -1,3 +1,4 @@
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
@@ -5,6 +6,7 @@ from cpl.graphql.schema.input import Input
from cpl.graphql.schema.mutation import Mutation
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.schema.subscription import Subscription
from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter
from model.post import Post
@@ -19,6 +21,7 @@ class PostFilter(DbModelFilter[Post]):
self.string_field("title")
self.string_field("content")
class PostSort(Sort[Post]):
def __init__(self):
Sort.__init__(self)
@@ -26,6 +29,7 @@ class PostSort(Sort[Post]):
self.field("title", SortOrder)
self.field("content", SortOrder)
class PostGraphType(DbModelGraphType[Post]):
def __init__(self, authors: AuthorDao):
@@ -42,7 +46,7 @@ class PostGraphType(DbModelGraphType[Post]):
def r_name(ctx: QueryContext):
return ctx.user.username == "admin"
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name]))
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name]))
self.string_field(
"title",
resolver=lambda root: root.title,
@@ -64,6 +68,7 @@ class PostCreateInput(Input[Post]):
self.string_field("content").with_required()
self.int_field("author_id").with_required()
class PostUpdateInput(Input[Post]):
title: str
content: str
@@ -75,13 +80,31 @@ class PostUpdateInput(Input[Post]):
self.string_field("title").with_required(False)
self.string_field("content").with_required(False)
class PostSubscription(Subscription):
def __init__(self, bus: EventBusABC):
Subscription.__init__(self)
self._bus = bus
async def post_changed():
async for event in await self._bus.subscribe("postChange"):
print("Event:", event, type(event))
yield event
def selector(event: Post, info) -> bool:
return True
self.subscription_field("postChange", PostGraphType, post_changed, selector)
class PostMutation(Mutation):
def __init__(self, posts: PostDao, authors: AuthorDao):
def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC):
Mutation.__init__(self)
self._posts = posts
self._authors = authors
self._bus = bus
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
"input",
@@ -112,6 +135,7 @@ class PostMutation(Mutation):
post.content = input.content if input.content is not None else post.content
await self._posts.update(post)
await self._bus.publish("postChange", post)
return True
async def delete_post(self, id: int) -> bool:
@@ -127,4 +151,3 @@ class PostMutation(Mutation):
return False
await self._posts.restore(post)
return True

View File

@@ -1,5 +1,5 @@
from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City
from api.src.queries.user import User, UserFilter, UserSort, UserGraphType
from queries.cities import CityFilter, CitySort, CityGraphType, City
from queries.user import User, UserFilter, UserSort, UserGraphType
from cpl.api.middleware.request import get_request
from cpl.auth.schema import UserDao, User
from cpl.graphql.schema.filter.filter import Filter
@@ -38,6 +38,7 @@ cities = [City(i, f"City {i}") for i in range(1, 101)]
# resolver=lambda root: root.username,
# )
class HelloQuery(Query):
def __init__(self):
Query.__init__(self)

View File

@@ -21,9 +21,7 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware):
def __init__(
self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao
):
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
ASGIMiddleware.__init__(self, app)
self._provider = provider

View File

@@ -119,6 +119,13 @@ class MySQLPool:
try:
async with await con.cursor(dictionary=True) as cursor:
res = await self._exec_sql(cursor, query, args, multi)
return list(res)
decoded_res = []
for row in res:
decoded_row = {
k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items()
}
decoded_res.append(decoded_row)
return decoded_res
finally:
await con.close()

View File

@@ -0,0 +1,10 @@
from abc import abstractmethod, ABC
from typing import Any, AsyncGenerator
class EventBusABC(ABC):
@abstractmethod
async def publish(self, channel: str, event: Any) -> None: ...
@abstractmethod
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ...

View File

@@ -1,9 +1,11 @@
from typing import Protocol, Type, runtime_checkable
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.subscription_field import SubscriptionField
@runtime_checkable
class StrawberryProtocol(Protocol):
def to_strawberry(self) -> Type: ...
def get_fields(self) -> dict[str, Field]: ...
def get_fields(self) -> dict[str, Field | SubscriptionField]: ...

View File

@@ -114,4 +114,6 @@ class GraphQLApp(WebApp):
if self._with_graphiql:
self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql")
if self._with_playground:
self._logger.warning(f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground")
self._logger.warning(
f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground"
)

View File

@@ -9,7 +9,6 @@ from cpl.graphql.schema.mutation import Mutation
class ApiKeyMutation(Mutation):
def __init__(
self,
logger: APILogger,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
@@ -61,9 +60,7 @@ class ApiKeyMutation(Mutation):
api_key = ApiKey.new(obj.identifier)
await self._api_key_dao.create(api_key)
api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}])
await self._api_key_permission_dao.create_many(
[ApiKeyPermission(0, api_key.id, x) for x in obj.permissions]
)
await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions])
return api_key
async def resolve_update(self, input: ApiKeyUpdateInput):

View File

@@ -0,0 +1,27 @@
import asyncio
from typing import Any, AsyncGenerator
from cpl.dependency.event_bus import EventBusABC
class InMemoryEventBus(EventBusABC):
def __init__(self):
self._subscribers: dict[str, list[asyncio.Queue]] = {}
async def publish(self, channel: str, event: Any) -> None:
queues = self._subscribers.get(channel, [])
for q in queues.copy():
await q.put(event)
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]:
q = asyncio.Queue()
if channel not in self._subscribers:
self._subscribers[channel] = []
self._subscribers[channel].append(q)
try:
while True:
item = await q.get()
yield item
finally:
self._subscribers[channel].remove(q)

View File

@@ -8,13 +8,14 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
from cpl.graphql.service.graphql import GraphQLService
from cpl.graphql.service.schema import Schema
class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [Schema, RootQuery, RootMutation]
singleton = [Schema, RootQuery, RootMutation, RootSubscription]
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]

View File

@@ -9,6 +9,7 @@ from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.utils.type_collector import TypeCollector
class CollectionGraphTypeFactory:
@classmethod
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
@@ -23,11 +24,7 @@ class CollectionGraphTypeFactory:
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_cls = type(
type_name,
(),
{}
)
gql_cls = type(type_name, (), {})
TypeCollector.set(type_name, gql_cls)
@@ -45,8 +42,6 @@ class CollectionGraphTypeFactory:
return gql_type
class Collection:
def __init__(self, nodes: list[T], total_count: int, count: int):
self._nodes = nodes

View File

@@ -26,6 +26,7 @@ class DbModelGraphType(GraphType[T], Generic[T]):
if Configuration.get("GraphQLAuthModuleEnabled", False):
from cpl.graphql.auth.user.user_graph_type import UserGraphType
self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public)
self.string_field("created", lambda root: root.created).with_public(public)

View File

@@ -16,6 +16,7 @@ class DbModelFilter(Filter[T], Generic[T]):
self.field("deleted", BoolFilter).with_public(public)
if Configuration.get("GraphQLAuthModuleEnabled", False):
from cpl.graphql.auth.user.user_filter import UserFilter
self.field("editor", lambda: UserFilter).with_public(public)
self.field("created", DateFilter).with_public(public)

View File

@@ -44,9 +44,7 @@ class Mutation(QueryABC):
if isinstance(reference_key_foreign, property):
reference_key_foreign_attr = reference_key_foreign.fget.__name__
foreign_list = await join_dao.find_by(
[{reference_key_own: resolved_obj.id}, {"deleted": False}]
)
foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}])
to_delete = (
foreign_list
@@ -61,9 +59,7 @@ class Mutation(QueryABC):
foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list]
deleted_foreign_ids = [
getattr(x, reference_key_foreign_attr)
for x in await join_dao.find_by(
[{reference_key_own: resolved_obj.id}, {"deleted": True}]
)
for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}])
]
to_create = [
@@ -94,6 +90,4 @@ class Mutation(QueryABC):
foreign_changes = [*to_delete, *to_create, *to_restore]
if len(foreign_changes) > 0:
await source_dao.touch(resolved_obj)
await foreign_dao.touch_many_by_id(
[getattr(x, reference_key_foreign_attr) for x in foreign_changes]
)
await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes])

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.subscription import Subscription
class RootSubscription(Subscription):
def __init__(self):
Subscription.__init__(self)

View File

@@ -0,0 +1,109 @@
import inspect
import types
from abc import ABC
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.dependency import ServiceProvider, get_provider, inject
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.subscription_field import SubscriptionField
from cpl.graphql.typing import Selector
from cpl.graphql.utils.type_collector import TypeCollector
class Subscription(ABC, StrawberryProtocol):
@inject
def __init__(self, provider: ServiceProvider):
ABC.__init__(self)
self._provider = provider
self._fields: Dict[str, SubscriptionField] = {}
@property
def fields(self) -> dict[str, SubscriptionField]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, SubscriptionField]:
return self._fields
def subscription_field(
self,
name: str,
type_: Type,
resolver: Callable[..., AsyncGenerator],
selector: Selector = None,
) -> SubscriptionField:
f = SubscriptionField(name, type_, resolver, selector)
self._fields[name] = f
return f
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField:
sub = self._provider.get_service(sub_cls)
if not sub:
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider")
async def _resolver(root, info):
return sub
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver)
return self._fields[name]
@staticmethod
def _type_to_strawberry(t: Type) -> Type:
_t = get_provider().get_service(t)
if isinstance(_t, StrawberryProtocol):
return _t.to_strawberry()
return t
@staticmethod
def _build_resolver(f: SubscriptionField) -> Callable:
async def _resolver(root, info):
async for event in f.resolver(root, info):
if not f.selector or f.selector(event, info):
yield event
return _resolver
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(cls.__name__, (), {})
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(t, types.GenericAlias):
t = t.__args__[0]
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
t = self._type_to_strawberry(t)
annotations[name] = Optional[t]
try:
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
except StrawberryException as e:
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
gql_cls.__annotations__ = annotations
for k, v in namespace.items():
setattr(gql_cls, k, v)
try:
gql_type = strawberry.type(gql_cls)
except Exception as e:
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,11 @@
from typing import Type, Callable
from cpl.graphql.typing import Selector
class SubscriptionField:
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None):
self.name = name
self.type = type_
self.resolver = resolver
self.selector = selector

View File

@@ -8,6 +8,7 @@ from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
class Schema:
@@ -38,6 +39,13 @@ class Schema:
raise ValueError("RootMutation not registered in service provider")
return mutation
@property
def subscription(self) -> RootSubscription:
subscription = self._provider.get_service(RootSubscription)
if not subscription:
raise ValueError("RootSubscription not registered in service provider")
return subscription
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
self._types[t.__name__] = t
return self
@@ -61,7 +69,7 @@ class Schema:
self._schema = strawberry.Schema(
query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=None,
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None,
types=self._get_types(),
)
return self._schema

View File

@@ -1,11 +1,14 @@
from enum import Enum
from typing import Type, Callable, List, Tuple, Awaitable
from typing import Type, Callable, List, Tuple, Awaitable, Any
import strawberry
from cpl.auth.permission import Permissions
from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"]
Resolver = Callable
Selector = Callable[[Any, strawberry.types.Info], bool]
ScalarType = str | int | float | bool | object
AttributeName = str | property
TRequireAnyPermissions = List[Enum | Permissions] | None