Added subscriptions final #181
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 7s

This commit is contained in:
2025-10-08 21:22:51 +02:00
parent 3774cef56a
commit 545540d05d
14 changed files with 243 additions and 116 deletions

View File

@@ -100,7 +100,7 @@ def main():
schema.mutation.with_mutation("post", PostMutation).with_public()
schema.subscription.with_subscription("post", PostSubscription)
schema.subscription.with_subscription(PostSubscription)
app.with_auth_root_queries(True)
app.with_auth_root_mutations(True)

View File

@@ -86,15 +86,10 @@ class PostSubscription(Subscription):
Subscription.__init__(self)
self._bus = bus
async def post_changed():
async for event in await self._bus.subscribe("postChange"):
print("Event:", event, type(event))
yield event
def selector(event: Post, info) -> bool:
return True
return event.id == 101
self.subscription_field("postChange", PostGraphType, post_changed, selector)
self.subscription_field("postChange", PostGraphType, selector).with_public()
class PostMutation(Mutation):

View File

@@ -169,6 +169,30 @@ class WebApp(WebAppABC):
return self
def with_websocket(
self,
path: str,
fn: TEndpoint,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
Router.websocket(path, registry=self._routes)(fn)
if authentication:
Router.authenticate()(fn)
if roles or permissions or policies:
Router.authorize(roles, permissions, policies, match)(fn)
return self
def with_middleware(self, middleware: PartialMiddleware) -> Self:
self._check_for_app()

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
from starlette.requests import Request
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
@@ -33,7 +34,7 @@ class RequestMiddleware(ASGIMiddleware):
self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = Request(scope, receive, send)
request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send)
await self.set_request_data(request)
try:

View File

@@ -0,0 +1,31 @@
from typing import Callable
import starlette.routing
class WebSocketRoute:
def __init__(self, path: str, fn: Callable, **kwargs):
self._path = path
self._fn = fn
self._kwargs = kwargs
@property
def name(self) -> str:
return self._fn.__name__
@property
def fn(self) -> Callable:
return self._fn
@property
def path(self) -> str:
return self._path
@property
def kwargs(self) -> dict:
return self._kwargs
def to_starlette(self, *args) -> starlette.routing.WebSocketRoute:
return starlette.routing.WebSocketRoute(self._path, self._fn)

View File

@@ -1,32 +1,35 @@
from typing import Optional
from typing import Optional, Union
from cpl.api.model.api_route import ApiRoute
from cpl.api.model.websocket_route import WebSocketRoute
from cpl.core.abc.registry_abc import RegistryABC
TRoute = Union[ApiRoute, WebSocketRoute]
class RouteRegistry(RegistryABC):
def __init__(self):
RegistryABC.__init__(self)
def extend(self, items: list[ApiRoute]):
def extend(self, items: list[TRoute]):
for policy in items:
self.add(policy)
def add(self, item: ApiRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
def add(self, item: TRoute):
assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute"
if item.path in self._items:
raise ValueError(f"ApiRoute {item.path} is already registered")
self._items[item.path] = item
def set(self, item: ApiRoute):
def set(self, item: TRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
self._items[item.path] = item
def get(self, key: str) -> Optional[ApiRoute]:
def get(self, key: str) -> Optional[TRoute]:
return self._items.get(key)
def all(self) -> list[ApiRoute]:
def all(self) -> list[TRoute]:
return list(self._items.values())

View File

@@ -91,6 +91,22 @@ class Router:
return inner
@classmethod
def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.websocket_route import WebSocketRoute
if not registry:
routes = get_provider().get_service(RouteRegistry)
else:
routes = registry
def inner(fn):
routes.add(WebSocketRoute(path, fn, **kwargs))
setattr(fn, "_route_path", path)
return fn
return inner
@classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.api_route import ApiRoute

View File

@@ -68,7 +68,7 @@ class StructuredLogger(Logger):
message["request"] = {
"url": str(request.url),
"method": request.method,
"method": request.method if request.scope == "http" else "websocket",
"scope": self._scope_to_json(request),
}
if isinstance(request, Request) and request.scope == "http":

View File

@@ -9,7 +9,10 @@ async def graphiql_endpoint(request):
<head>
<meta charset="utf-8" />
<title>GraphiQL</title>
<link href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css" rel="stylesheet" />
<link
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
rel="stylesheet"
/>
</head>
<body style="margin:0;overflow:hidden;">
<div id="graphiql" style="height:100vh;"></div>
@@ -21,13 +24,39 @@ async def graphiql_endpoint(request):
<!-- GraphiQL -->
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
<!-- GraphQL over WebSocket client -->
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
<script>
const graphQLFetcher = graphQLParams =>
fetch('/api/graphql', {
method: 'post',
const httpUrl = window.location.origin + '/api/graphql';
const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
window.location.host + '/api/graphql/ws';
// HTTP fetcher for queries & mutations
const httpFetcher = async (params) => {
const res = await fetch(httpUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(graphQLParams),
}).then(response => response.json()).catch(() => response.text());
body: JSON.stringify(params),
});
return res.json();
};
// WebSocket fetcher for subscriptions
const wsClient = graphqlWs.createClient({ url: wsUrl });
const wsFetcher = (params) => ({
subscribe: (sink) => ({
unsubscribe: wsClient.subscribe(params, sink),
}),
});
// smart fetcher wrapper (decides HTTP or WS)
const graphQLFetcher = (params) => {
if (params.query.trim().startsWith('subscription')) {
return wsFetcher(params);
}
return httpFetcher(params);
};
ReactDOM.render(
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),

View File

@@ -0,0 +1,27 @@
from starlette.requests import Request
from starlette.responses import Response
from strawberry.asgi import GraphQL
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from cpl.dependency import ServiceProvider
from cpl.graphql.service.schema import Schema
class LazyGraphQLApp:
def __init__(self, services: ServiceProvider):
self._services = services
self._graphql_app = None
async def __call__(self, scope, receive, send):
if self._graphql_app is None:
schema = self._services.get_service(Schema)
if not schema or not schema.schema:
raise RuntimeError("GraphQL Schema not available yet")
self._graphql_app = GraphQL(
schema.schema,
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
)
await self._graphql_app(scope, receive, send)

View File

@@ -10,6 +10,7 @@ from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
from cpl.graphql._endpoints.graphql import graphql_endpoint
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
from cpl.graphql._endpoints.playground import playground_endpoint
from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.service.schema import Schema
@@ -43,6 +44,12 @@ class GraphQLApp(WebApp):
schema = self._services.get_service(Schema)
if schema is None:
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
#
# graphql_ws_app = GraphQL(
# schema,
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
# )
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
return schema
def with_graphiql(

View File

@@ -1,109 +1,88 @@
import inspect
import types
from abc import ABC
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
from typing import Any, Type, Optional, Self
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.dependency import ServiceProvider, get_provider, inject
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider, inject
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.subscription_field import SubscriptionField
from cpl.graphql.typing import Selector
from cpl.graphql.utils.type_collector import TypeCollector
class Subscription(ABC, StrawberryProtocol):
class Subscription(QueryABC):
@inject
def __init__(self, provider: ServiceProvider):
ABC.__init__(self)
self._provider = provider
self._fields: Dict[str, SubscriptionField] = {}
@property
def fields(self) -> dict[str, SubscriptionField]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, SubscriptionField]:
return self._fields
def __init__(self, bus: EventBusABC):
QueryABC.__init__(self)
self._bus = bus
def subscription_field(
self,
name: str,
type_: Type,
resolver: Callable[..., AsyncGenerator],
selector: Selector = None,
t: Type,
selector: Optional[Selector] = None,
channel: Optional[str] = None,
) -> SubscriptionField:
f = SubscriptionField(name, type_, resolver, selector)
self._fields[name] = f
return f
field = SubscriptionField(name, t, selector, channel)
self._fields[name] = field
return field
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField:
sub = self._provider.get_service(sub_cls)
def with_subscription(self, sub_cls: Type[Self]) -> Self:
sub = get_provider().get_service(sub_cls)
if not sub:
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider")
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider")
async def _resolver(root, info):
return sub
for sub_name, sub_field in sub.get_fields().items():
self._fields[sub_name] = sub_field
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver)
return self._fields[name]
@staticmethod
def _type_to_strawberry(t: Type) -> Type:
_t = get_provider().get_service(t)
if isinstance(_t, StrawberryProtocol):
return _t.to_strawberry()
return t
@staticmethod
def _build_resolver(f: SubscriptionField) -> Callable:
async def _resolver(root, info):
async for event in f.resolver(root, info):
if not f.selector or f.selector(event, info):
yield event
return _resolver
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(cls.__name__, (), {})
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(t, types.GenericAlias):
t = t.__args__[0]
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
t = self._type_to_strawberry(t)
annotations[name] = Optional[t]
try:
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
except StrawberryException as e:
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
gql_cls.__annotations__ = annotations
for k, v in namespace.items():
setattr(gql_cls, k, v)
return self
def _field_to_strawberry(self, f: SubscriptionField) -> Any:
try:
gql_type = strawberry.type(gql_cls)
except Exception as e:
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
if isinstance(f, SubscriptionField):
TypeCollector.set(cls, gql_type)
return gql_type
def make_resolver(field: SubscriptionField):
async def resolver(root=None, info=None):
if not field.public:
user = get_user()
if not user:
raise graphql_error(Unauthorized(f"{field.name}: Authentication required"))
if field.require_any_permission:
ok = any([await user.has_permission(p) for p in field.require_any_permission])
if not ok:
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
if field.require_any:
perms, resolvers = field.require_any
ok = any([await user.has_permission(p) for p in perms])
if not ok:
ctx = QueryContext([x.name for x in await user.permissions])
results = [
r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx)
for r in resolvers
]
if not any(results):
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
async for event in self._bus.subscribe(field.channel):
if field.selector is None or field.selector(event, info):
yield event
return resolver
return strawberry.subscription(resolver=make_resolver(f))
async def wrapper_resolver(root=None, info=None):
yield None
return strawberry.subscription(resolver=wrapper_resolver)
except StrawberryException as e:
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e

View File

@@ -1,11 +1,25 @@
from typing import Type, Callable
from typing import Type, Callable, Optional
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Selector
class SubscriptionField:
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None):
self.name = name
self.type = type_
self.resolver = resolver
class SubscriptionField(Field):
def __init__(
self,
name: str,
t: Type,
selector: Optional[Selector] = None,
channel: Optional[str] = None,
):
super().__init__(name, t)
self.selector = selector
self.channel = channel or name
def with_selector(self, selector: Selector) -> "SubscriptionField":
self.selector = selector
return self
def with_channel(self, channel: str) -> "SubscriptionField":
self.channel = channel
return self

View File

@@ -65,11 +65,12 @@ class Schema:
query = self.query
mutation = self.mutation
subscription = self.subscription
self._schema = strawberry.Schema(
query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None,
subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None,
types=self._get_types(),
)
return self._schema