[WIP] Subscriptions
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s

This commit is contained in:
2025-10-04 06:57:14 +02:00
parent e362b7fb61
commit cdb5e4ff89
21 changed files with 254 additions and 52 deletions

View File

@@ -21,9 +21,7 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware):
def __init__(
self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao
):
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
ASGIMiddleware.__init__(self, app)
self._provider = provider

View File

@@ -119,6 +119,13 @@ class MySQLPool:
try:
async with await con.cursor(dictionary=True) as cursor:
res = await self._exec_sql(cursor, query, args, multi)
return list(res)
decoded_res = []
for row in res:
decoded_row = {
k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items()
}
decoded_res.append(decoded_row)
return decoded_res
finally:
await con.close()

View File

@@ -0,0 +1,10 @@
from abc import abstractmethod, ABC
from typing import Any, AsyncGenerator
class EventBusABC(ABC):
@abstractmethod
async def publish(self, channel: str, event: Any) -> None: ...
@abstractmethod
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ...

View File

@@ -1,9 +1,11 @@
from typing import Protocol, Type, runtime_checkable
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.subscription_field import SubscriptionField
@runtime_checkable
class StrawberryProtocol(Protocol):
def to_strawberry(self) -> Type: ...
def get_fields(self) -> dict[str, Field]: ...
def get_fields(self) -> dict[str, Field | SubscriptionField]: ...

View File

@@ -114,4 +114,6 @@ class GraphQLApp(WebApp):
if self._with_graphiql:
self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql")
if self._with_playground:
self._logger.warning(f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground")
self._logger.warning(
f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground"
)

View File

@@ -8,13 +8,12 @@ from cpl.graphql.schema.mutation import Mutation
class ApiKeyMutation(Mutation):
def __init__(
self,
logger: APILogger,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
permission_dao: ApiKeyPermissionDao,
keycloak_admin: KeycloakAdmin,
self,
logger: APILogger,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
permission_dao: ApiKeyPermissionDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
@@ -61,9 +60,7 @@ class ApiKeyMutation(Mutation):
api_key = ApiKey.new(obj.identifier)
await self._api_key_dao.create(api_key)
api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}])
await self._api_key_permission_dao.create_many(
[ApiKeyPermission(0, api_key.id, x) for x in obj.permissions]
)
await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions])
return api_key
async def resolve_update(self, input: ApiKeyUpdateInput):

View File

@@ -0,0 +1,27 @@
import asyncio
from typing import Any, AsyncGenerator
from cpl.dependency.event_bus import EventBusABC
class InMemoryEventBus(EventBusABC):
def __init__(self):
self._subscribers: dict[str, list[asyncio.Queue]] = {}
async def publish(self, channel: str, event: Any) -> None:
queues = self._subscribers.get(channel, [])
for q in queues.copy():
await q.put(event)
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]:
q = asyncio.Queue()
if channel not in self._subscribers:
self._subscribers[channel] = []
self._subscribers[channel].append(q)
try:
while True:
item = await q.get()
yield item
finally:
self._subscribers[channel].remove(q)

View File

@@ -8,13 +8,14 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
from cpl.graphql.service.graphql import GraphQLService
from cpl.graphql.service.schema import Schema
class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [Schema, RootQuery, RootMutation]
singleton = [Schema, RootQuery, RootMutation, RootSubscription]
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]

View File

@@ -9,6 +9,7 @@ from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.utils.type_collector import TypeCollector
class CollectionGraphTypeFactory:
@classmethod
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
@@ -23,11 +24,7 @@ class CollectionGraphTypeFactory:
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_cls = type(
type_name,
(),
{}
)
gql_cls = type(type_name, (), {})
TypeCollector.set(type_name, gql_cls)
@@ -45,8 +42,6 @@ class CollectionGraphTypeFactory:
return gql_type
class Collection:
def __init__(self, nodes: list[T], total_count: int, count: int):
self._nodes = nodes

View File

@@ -26,6 +26,7 @@ class DbModelGraphType(GraphType[T], Generic[T]):
if Configuration.get("GraphQLAuthModuleEnabled", False):
from cpl.graphql.auth.user.user_graph_type import UserGraphType
self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public)
self.string_field("created", lambda root: root.created).with_public(public)

View File

@@ -16,6 +16,7 @@ class DbModelFilter(Filter[T], Generic[T]):
self.field("deleted", BoolFilter).with_public(public)
if Configuration.get("GraphQLAuthModuleEnabled", False):
from cpl.graphql.auth.user.user_filter import UserFilter
self.field("editor", lambda: UserFilter).with_public(public)
self.field("created", DateFilter).with_public(public)

View File

@@ -28,14 +28,14 @@ class Mutation(QueryABC):
@staticmethod
async def _resolve_assignments(
foreign_objs: list[int],
resolved_obj: T,
reference_key_own: Union[str, property],
reference_key_foreign: Union[str, property],
source_dao: DataAccessObjectABC[T],
join_dao: DataAccessObjectABC[T],
join_type: Type[DbJoinModelABC],
foreign_dao: DataAccessObjectABC[T],
foreign_objs: list[int],
resolved_obj: T,
reference_key_own: Union[str, property],
reference_key_foreign: Union[str, property],
source_dao: DataAccessObjectABC[T],
join_dao: DataAccessObjectABC[T],
join_type: Type[DbJoinModelABC],
foreign_dao: DataAccessObjectABC[T],
):
if foreign_objs is None:
return
@@ -44,9 +44,7 @@ class Mutation(QueryABC):
if isinstance(reference_key_foreign, property):
reference_key_foreign_attr = reference_key_foreign.fget.__name__
foreign_list = await join_dao.find_by(
[{reference_key_own: resolved_obj.id}, {"deleted": False}]
)
foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}])
to_delete = (
foreign_list
@@ -61,9 +59,7 @@ class Mutation(QueryABC):
foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list]
deleted_foreign_ids = [
getattr(x, reference_key_foreign_attr)
for x in await join_dao.find_by(
[{reference_key_own: resolved_obj.id}, {"deleted": True}]
)
for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}])
]
to_create = [
@@ -94,6 +90,4 @@ class Mutation(QueryABC):
foreign_changes = [*to_delete, *to_create, *to_restore]
if len(foreign_changes) > 0:
await source_dao.touch(resolved_obj)
await foreign_dao.touch_many_by_id(
[getattr(x, reference_key_foreign_attr) for x in foreign_changes]
)
await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes])

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.subscription import Subscription
class RootSubscription(Subscription):
def __init__(self):
Subscription.__init__(self)

View File

@@ -0,0 +1,109 @@
import inspect
import types
from abc import ABC
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.dependency import ServiceProvider, get_provider, inject
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.subscription_field import SubscriptionField
from cpl.graphql.typing import Selector
from cpl.graphql.utils.type_collector import TypeCollector
class Subscription(ABC, StrawberryProtocol):
@inject
def __init__(self, provider: ServiceProvider):
ABC.__init__(self)
self._provider = provider
self._fields: Dict[str, SubscriptionField] = {}
@property
def fields(self) -> dict[str, SubscriptionField]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, SubscriptionField]:
return self._fields
def subscription_field(
self,
name: str,
type_: Type,
resolver: Callable[..., AsyncGenerator],
selector: Selector = None,
) -> SubscriptionField:
f = SubscriptionField(name, type_, resolver, selector)
self._fields[name] = f
return f
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField:
sub = self._provider.get_service(sub_cls)
if not sub:
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider")
async def _resolver(root, info):
return sub
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver)
return self._fields[name]
@staticmethod
def _type_to_strawberry(t: Type) -> Type:
_t = get_provider().get_service(t)
if isinstance(_t, StrawberryProtocol):
return _t.to_strawberry()
return t
@staticmethod
def _build_resolver(f: SubscriptionField) -> Callable:
async def _resolver(root, info):
async for event in f.resolver(root, info):
if not f.selector or f.selector(event, info):
yield event
return _resolver
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(cls.__name__, (), {})
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(t, types.GenericAlias):
t = t.__args__[0]
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
t = self._type_to_strawberry(t)
annotations[name] = Optional[t]
try:
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
except StrawberryException as e:
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
gql_cls.__annotations__ = annotations
for k, v in namespace.items():
setattr(gql_cls, k, v)
try:
gql_type = strawberry.type(gql_cls)
except Exception as e:
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,11 @@
from typing import Type, Callable
from cpl.graphql.typing import Selector
class SubscriptionField:
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None):
self.name = name
self.type = type_
self.resolver = resolver
self.selector = selector

View File

@@ -8,6 +8,7 @@ from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
class Schema:
@@ -38,6 +39,13 @@ class Schema:
raise ValueError("RootMutation not registered in service provider")
return mutation
@property
def subscription(self) -> RootSubscription:
subscription = self._provider.get_service(RootSubscription)
if not subscription:
raise ValueError("RootSubscription not registered in service provider")
return subscription
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
self._types[t.__name__] = t
return self
@@ -61,7 +69,7 @@ class Schema:
self._schema = strawberry.Schema(
query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=None,
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None,
types=self._get_types(),
)
return self._schema

View File

@@ -1,11 +1,14 @@
from enum import Enum
from typing import Type, Callable, List, Tuple, Awaitable
from typing import Type, Callable, List, Tuple, Awaitable, Any
import strawberry
from cpl.auth.permission import Permissions
from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"]
Resolver = Callable
Selector = Callable[[Any, strawberry.types.Info], bool]
ScalarType = str | int | float | bool | object
AttributeName = str | property
TRequireAnyPermissions = List[Enum | Permissions] | None