Added subscriptions final #181
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 7s

This commit is contained in:
2025-10-08 21:22:51 +02:00
parent 3774cef56a
commit 545540d05d
14 changed files with 243 additions and 116 deletions

View File

@@ -100,7 +100,7 @@ def main():
schema.mutation.with_mutation("post", PostMutation).with_public() schema.mutation.with_mutation("post", PostMutation).with_public()
schema.subscription.with_subscription("post", PostSubscription) schema.subscription.with_subscription(PostSubscription)
app.with_auth_root_queries(True) app.with_auth_root_queries(True)
app.with_auth_root_mutations(True) app.with_auth_root_mutations(True)

View File

@@ -86,15 +86,10 @@ class PostSubscription(Subscription):
Subscription.__init__(self) Subscription.__init__(self)
self._bus = bus self._bus = bus
async def post_changed():
async for event in await self._bus.subscribe("postChange"):
print("Event:", event, type(event))
yield event
def selector(event: Post, info) -> bool: def selector(event: Post, info) -> bool:
return True return event.id == 101
self.subscription_field("postChange", PostGraphType, post_changed, selector) self.subscription_field("postChange", PostGraphType, selector).with_public()
class PostMutation(Mutation): class PostMutation(Mutation):

View File

@@ -169,6 +169,30 @@ class WebApp(WebAppABC):
return self return self
def with_websocket(
self,
path: str,
fn: TEndpoint,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
Router.websocket(path, registry=self._routes)(fn)
if authentication:
Router.authenticate()(fn)
if roles or permissions or policies:
Router.authorize(roles, permissions, policies, match)(fn)
return self
def with_middleware(self, middleware: PartialMiddleware) -> Self: def with_middleware(self, middleware: PartialMiddleware) -> Self:
self._check_for_app() self._check_for_app()

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
from starlette.requests import Request from starlette.requests import Request
from starlette.types import Scope, Receive, Send from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
@@ -33,7 +34,7 @@ class RequestMiddleware(ASGIMiddleware):
self._ctx_token = None self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = Request(scope, receive, send) request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send)
await self.set_request_data(request) await self.set_request_data(request)
try: try:

View File

@@ -0,0 +1,31 @@
from typing import Callable
import starlette.routing
class WebSocketRoute:
def __init__(self, path: str, fn: Callable, **kwargs):
self._path = path
self._fn = fn
self._kwargs = kwargs
@property
def name(self) -> str:
return self._fn.__name__
@property
def fn(self) -> Callable:
return self._fn
@property
def path(self) -> str:
return self._path
@property
def kwargs(self) -> dict:
return self._kwargs
def to_starlette(self, *args) -> starlette.routing.WebSocketRoute:
return starlette.routing.WebSocketRoute(self._path, self._fn)

View File

@@ -1,32 +1,35 @@
from typing import Optional from typing import Optional, Union
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.api.model.websocket_route import WebSocketRoute
from cpl.core.abc.registry_abc import RegistryABC from cpl.core.abc.registry_abc import RegistryABC
TRoute = Union[ApiRoute, WebSocketRoute]
class RouteRegistry(RegistryABC): class RouteRegistry(RegistryABC):
def __init__(self): def __init__(self):
RegistryABC.__init__(self) RegistryABC.__init__(self)
def extend(self, items: list[ApiRoute]): def extend(self, items: list[TRoute]):
for policy in items: for policy in items:
self.add(policy) self.add(policy)
def add(self, item: ApiRoute): def add(self, item: TRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute"
if item.path in self._items: if item.path in self._items:
raise ValueError(f"ApiRoute {item.path} is already registered") raise ValueError(f"ApiRoute {item.path} is already registered")
self._items[item.path] = item self._items[item.path] = item
def set(self, item: ApiRoute): def set(self, item: TRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
self._items[item.path] = item self._items[item.path] = item
def get(self, key: str) -> Optional[ApiRoute]: def get(self, key: str) -> Optional[TRoute]:
return self._items.get(key) return self._items.get(key)
def all(self) -> list[ApiRoute]: def all(self) -> list[TRoute]:
return list(self._items.values()) return list(self._items.values())

View File

@@ -91,6 +91,22 @@ class Router:
return inner return inner
@classmethod
def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.websocket_route import WebSocketRoute
if not registry:
routes = get_provider().get_service(RouteRegistry)
else:
routes = registry
def inner(fn):
routes.add(WebSocketRoute(path, fn, **kwargs))
setattr(fn, "_route_path", path)
return fn
return inner
@classmethod @classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute

View File

@@ -68,7 +68,7 @@ class StructuredLogger(Logger):
message["request"] = { message["request"] = {
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method if request.scope == "http" else "websocket",
"scope": self._scope_to_json(request), "scope": self._scope_to_json(request),
} }
if isinstance(request, Request) and request.scope == "http": if isinstance(request, Request) and request.scope == "http":

View File

@@ -9,7 +9,10 @@ async def graphiql_endpoint(request):
<head> <head>
<meta charset="utf-8" /> <meta charset="utf-8" />
<title>GraphiQL</title> <title>GraphiQL</title>
<link href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css" rel="stylesheet" /> <link
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
rel="stylesheet"
/>
</head> </head>
<body style="margin:0;overflow:hidden;"> <body style="margin:0;overflow:hidden;">
<div id="graphiql" style="height:100vh;"></div> <div id="graphiql" style="height:100vh;"></div>
@@ -21,13 +24,39 @@ async def graphiql_endpoint(request):
<!-- GraphiQL --> <!-- GraphiQL -->
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script> <script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
<!-- GraphQL over WebSocket client -->
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
<script> <script>
const graphQLFetcher = graphQLParams => const httpUrl = window.location.origin + '/api/graphql';
fetch('/api/graphql', { const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
method: 'post', window.location.host + '/api/graphql/ws';
// HTTP fetcher for queries & mutations
const httpFetcher = async (params) => {
const res = await fetch(httpUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(graphQLParams), body: JSON.stringify(params),
}).then(response => response.json()).catch(() => response.text()); });
return res.json();
};
// WebSocket fetcher for subscriptions
const wsClient = graphqlWs.createClient({ url: wsUrl });
const wsFetcher = (params) => ({
subscribe: (sink) => ({
unsubscribe: wsClient.subscribe(params, sink),
}),
});
// smart fetcher wrapper (decides HTTP or WS)
const graphQLFetcher = (params) => {
if (params.query.trim().startsWith('subscription')) {
return wsFetcher(params);
}
return httpFetcher(params);
};
ReactDOM.render( ReactDOM.render(
React.createElement(GraphiQL, { fetcher: graphQLFetcher }), React.createElement(GraphiQL, { fetcher: graphQLFetcher }),

View File

@@ -0,0 +1,27 @@
from starlette.requests import Request
from starlette.responses import Response
from strawberry.asgi import GraphQL
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from cpl.dependency import ServiceProvider
from cpl.graphql.service.schema import Schema
class LazyGraphQLApp:
def __init__(self, services: ServiceProvider):
self._services = services
self._graphql_app = None
async def __call__(self, scope, receive, send):
if self._graphql_app is None:
schema = self._services.get_service(Schema)
if not schema or not schema.schema:
raise RuntimeError("GraphQL Schema not available yet")
self._graphql_app = GraphQL(
schema.schema,
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
)
await self._graphql_app(scope, receive, send)

View File

@@ -10,6 +10,7 @@ from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules from cpl.dependency.typing import Modules
from cpl.graphql._endpoints.graphiql import graphiql_endpoint from cpl.graphql._endpoints.graphiql import graphiql_endpoint
from cpl.graphql._endpoints.graphql import graphql_endpoint from cpl.graphql._endpoints.graphql import graphql_endpoint
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
from cpl.graphql._endpoints.playground import playground_endpoint from cpl.graphql._endpoints.playground import playground_endpoint
from cpl.graphql.graphql_module import GraphQLModule from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.service.schema import Schema from cpl.graphql.service.schema import Schema
@@ -43,6 +44,12 @@ class GraphQLApp(WebApp):
schema = self._services.get_service(Schema) schema = self._services.get_service(Schema)
if schema is None: if schema is None:
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.") self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
#
# graphql_ws_app = GraphQL(
# schema,
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
# )
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
return schema return schema
def with_graphiql( def with_graphiql(

View File

@@ -1,109 +1,88 @@
import inspect import inspect
import types from typing import Any, Type, Optional, Self
from abc import ABC
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
import strawberry import strawberry
from strawberry.exceptions import StrawberryException from strawberry.exceptions import StrawberryException
from cpl.dependency import ServiceProvider, get_provider, inject from cpl.api import Unauthorized, Forbidden
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider, inject
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.subscription_field import SubscriptionField from cpl.graphql.schema.subscription_field import SubscriptionField
from cpl.graphql.typing import Selector from cpl.graphql.typing import Selector
from cpl.graphql.utils.type_collector import TypeCollector
class Subscription(ABC, StrawberryProtocol): class Subscription(QueryABC):
@inject @inject
def __init__(self, provider: ServiceProvider): def __init__(self, bus: EventBusABC):
ABC.__init__(self) QueryABC.__init__(self)
self._provider = provider self._bus = bus
self._fields: Dict[str, SubscriptionField] = {}
@property
def fields(self) -> dict[str, SubscriptionField]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, SubscriptionField]:
return self._fields
def subscription_field( def subscription_field(
self, self,
name: str, name: str,
type_: Type, t: Type,
resolver: Callable[..., AsyncGenerator], selector: Optional[Selector] = None,
selector: Selector = None, channel: Optional[str] = None,
) -> SubscriptionField: ) -> SubscriptionField:
f = SubscriptionField(name, type_, resolver, selector) field = SubscriptionField(name, t, selector, channel)
self._fields[name] = f self._fields[name] = field
return f return field
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField: def with_subscription(self, sub_cls: Type[Self]) -> Self:
sub = self._provider.get_service(sub_cls) sub = get_provider().get_service(sub_cls)
if not sub: if not sub:
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider") raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider")
async def _resolver(root, info): for sub_name, sub_field in sub.get_fields().items():
return sub self._fields[sub_name] = sub_field
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver) return self
return self._fields[name]
@staticmethod def _field_to_strawberry(self, f: SubscriptionField) -> Any:
def _type_to_strawberry(t: Type) -> Type: try:
_t = get_provider().get_service(t) if isinstance(f, SubscriptionField):
if isinstance(_t, StrawberryProtocol):
return _t.to_strawberry()
return t
@staticmethod def make_resolver(field: SubscriptionField):
def _build_resolver(f: SubscriptionField) -> Callable: async def resolver(root=None, info=None):
async def _resolver(root, info): if not field.public:
async for event in f.resolver(root, info): user = get_user()
if not f.selector or f.selector(event, info): if not user:
raise graphql_error(Unauthorized(f"{field.name}: Authentication required"))
if field.require_any_permission:
ok = any([await user.has_permission(p) for p in field.require_any_permission])
if not ok:
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
if field.require_any:
perms, resolvers = field.require_any
ok = any([await user.has_permission(p) for p in perms])
if not ok:
ctx = QueryContext([x.name for x in await user.permissions])
results = [
r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx)
for r in resolvers
]
if not any(results):
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
async for event in self._bus.subscribe(field.channel):
if field.selector is None or field.selector(event, info):
yield event yield event
return _resolver return resolver
def to_strawberry(self) -> Type: return strawberry.subscription(resolver=make_resolver(f))
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(cls.__name__, (), {}) async def wrapper_resolver(root=None, info=None):
TypeCollector.set(cls, gql_cls) yield None
annotations: dict[str, Any] = {} return strawberry.subscription(resolver=wrapper_resolver)
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(t, types.GenericAlias):
t = t.__args__[0]
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
t = self._type_to_strawberry(t)
annotations[name] = Optional[t]
try:
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
except StrawberryException as e: except StrawberryException as e:
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
gql_cls.__annotations__ = annotations
for k, v in namespace.items():
setattr(gql_cls, k, v)
try:
gql_type = strawberry.type(gql_cls)
except Exception as e:
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -1,11 +1,25 @@
from typing import Type, Callable from typing import Type, Callable, Optional
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Selector from cpl.graphql.typing import Selector
class SubscriptionField: class SubscriptionField(Field):
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None): def __init__(
self.name = name self,
self.type = type_ name: str,
self.resolver = resolver t: Type,
selector: Optional[Selector] = None,
channel: Optional[str] = None,
):
super().__init__(name, t)
self.selector = selector self.selector = selector
self.channel = channel or name
def with_selector(self, selector: Selector) -> "SubscriptionField":
self.selector = selector
return self
def with_channel(self, channel: str) -> "SubscriptionField":
self.channel = channel
return self

View File

@@ -65,11 +65,12 @@ class Schema:
query = self.query query = self.query
mutation = self.mutation mutation = self.mutation
subscription = self.subscription
self._schema = strawberry.Schema( self._schema = strawberry.Schema(
query=query.to_strawberry() if query.fields_count > 0 else None, query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None, mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None, subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None,
types=self._get_types(), types=self._get_types(),
) )
return self._schema return self._schema