Added subscriptions final #181
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 7s
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 7s
This commit is contained in:
@@ -100,7 +100,7 @@ def main():
|
||||
|
||||
schema.mutation.with_mutation("post", PostMutation).with_public()
|
||||
|
||||
schema.subscription.with_subscription("post", PostSubscription)
|
||||
schema.subscription.with_subscription(PostSubscription)
|
||||
|
||||
app.with_auth_root_queries(True)
|
||||
app.with_auth_root_mutations(True)
|
||||
|
||||
@@ -86,15 +86,10 @@ class PostSubscription(Subscription):
|
||||
Subscription.__init__(self)
|
||||
self._bus = bus
|
||||
|
||||
async def post_changed():
|
||||
async for event in await self._bus.subscribe("postChange"):
|
||||
print("Event:", event, type(event))
|
||||
yield event
|
||||
|
||||
def selector(event: Post, info) -> bool:
|
||||
return True
|
||||
return event.id == 101
|
||||
|
||||
self.subscription_field("postChange", PostGraphType, post_changed, selector)
|
||||
self.subscription_field("postChange", PostGraphType, selector).with_public()
|
||||
|
||||
|
||||
class PostMutation(Mutation):
|
||||
|
||||
@@ -169,6 +169,30 @@ class WebApp(WebAppABC):
|
||||
|
||||
return self
|
||||
|
||||
def with_websocket(
|
||||
self,
|
||||
path: str,
|
||||
fn: TEndpoint,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self:
|
||||
self._check_for_app()
|
||||
assert path is not None, "path must not be None"
|
||||
assert fn is not None, "fn must not be None"
|
||||
|
||||
Router.websocket(path, registry=self._routes)(fn)
|
||||
|
||||
if authentication:
|
||||
Router.authenticate()(fn)
|
||||
|
||||
if roles or permissions or policies:
|
||||
Router.authorize(roles, permissions, policies, match)(fn)
|
||||
|
||||
return self
|
||||
|
||||
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
||||
self._check_for_app()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Scope, Receive, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
@@ -33,7 +34,7 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
self._ctx_token = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = Request(scope, receive, send)
|
||||
request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send)
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
|
||||
31
src/cpl-api/cpl/api/model/websocket_route.py
Normal file
31
src/cpl-api/cpl/api/model/websocket_route.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Callable
|
||||
|
||||
import starlette.routing
|
||||
|
||||
|
||||
class WebSocketRoute:
|
||||
|
||||
def __init__(self, path: str, fn: Callable, **kwargs):
|
||||
self._path = path
|
||||
self._fn = fn
|
||||
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._fn.__name__
|
||||
|
||||
@property
|
||||
def fn(self) -> Callable:
|
||||
return self._fn
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self._path
|
||||
|
||||
@property
|
||||
def kwargs(self) -> dict:
|
||||
return self._kwargs
|
||||
|
||||
def to_starlette(self, *args) -> starlette.routing.WebSocketRoute:
|
||||
return starlette.routing.WebSocketRoute(self._path, self._fn)
|
||||
@@ -1,32 +1,35 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.api.model.websocket_route import WebSocketRoute
|
||||
from cpl.core.abc.registry_abc import RegistryABC
|
||||
|
||||
TRoute = Union[ApiRoute, WebSocketRoute]
|
||||
|
||||
|
||||
class RouteRegistry(RegistryABC):
|
||||
|
||||
def __init__(self):
|
||||
RegistryABC.__init__(self)
|
||||
|
||||
def extend(self, items: list[ApiRoute]):
|
||||
def extend(self, items: list[TRoute]):
|
||||
for policy in items:
|
||||
self.add(policy)
|
||||
|
||||
def add(self, item: ApiRoute):
|
||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||
def add(self, item: TRoute):
|
||||
assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute"
|
||||
|
||||
if item.path in self._items:
|
||||
raise ValueError(f"ApiRoute {item.path} is already registered")
|
||||
|
||||
self._items[item.path] = item
|
||||
|
||||
def set(self, item: ApiRoute):
|
||||
def set(self, item: TRoute):
|
||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||
self._items[item.path] = item
|
||||
|
||||
def get(self, key: str) -> Optional[ApiRoute]:
|
||||
def get(self, key: str) -> Optional[TRoute]:
|
||||
return self._items.get(key)
|
||||
|
||||
def all(self) -> list[ApiRoute]:
|
||||
def all(self) -> list[TRoute]:
|
||||
return list(self._items.values())
|
||||
|
||||
@@ -91,6 +91,22 @@ class Router:
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs):
|
||||
from cpl.api.model.websocket_route import WebSocketRoute
|
||||
|
||||
if not registry:
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
def inner(fn):
|
||||
routes.add(WebSocketRoute(path, fn, **kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
@@ -68,7 +68,7 @@ class StructuredLogger(Logger):
|
||||
|
||||
message["request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"method": request.method if request.scope == "http" else "websocket",
|
||||
"scope": self._scope_to_json(request),
|
||||
}
|
||||
if isinstance(request, Request) and request.scope == "http":
|
||||
|
||||
@@ -9,7 +9,10 @@ async def graphiql_endpoint(request):
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>GraphiQL</title>
|
||||
<link href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css" rel="stylesheet" />
|
||||
<link
|
||||
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
</head>
|
||||
<body style="margin:0;overflow:hidden;">
|
||||
<div id="graphiql" style="height:100vh;"></div>
|
||||
@@ -21,13 +24,39 @@ async def graphiql_endpoint(request):
|
||||
<!-- GraphiQL -->
|
||||
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
|
||||
|
||||
<!-- GraphQL over WebSocket client -->
|
||||
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
|
||||
|
||||
<script>
|
||||
const graphQLFetcher = graphQLParams =>
|
||||
fetch('/api/graphql', {
|
||||
method: 'post',
|
||||
const httpUrl = window.location.origin + '/api/graphql';
|
||||
const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
|
||||
window.location.host + '/api/graphql/ws';
|
||||
|
||||
// HTTP fetcher for queries & mutations
|
||||
const httpFetcher = async (params) => {
|
||||
const res = await fetch(httpUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(graphQLParams),
|
||||
}).then(response => response.json()).catch(() => response.text());
|
||||
body: JSON.stringify(params),
|
||||
});
|
||||
return res.json();
|
||||
};
|
||||
|
||||
// WebSocket fetcher for subscriptions
|
||||
const wsClient = graphqlWs.createClient({ url: wsUrl });
|
||||
const wsFetcher = (params) => ({
|
||||
subscribe: (sink) => ({
|
||||
unsubscribe: wsClient.subscribe(params, sink),
|
||||
}),
|
||||
});
|
||||
|
||||
// smart fetcher wrapper (decides HTTP or WS)
|
||||
const graphQLFetcher = (params) => {
|
||||
if (params.query.trim().startsWith('subscription')) {
|
||||
return wsFetcher(params);
|
||||
}
|
||||
return httpFetcher(params);
|
||||
};
|
||||
|
||||
ReactDOM.render(
|
||||
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),
|
||||
|
||||
27
src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py
Normal file
27
src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from strawberry.asgi import GraphQL
|
||||
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
|
||||
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class LazyGraphQLApp:
|
||||
|
||||
def __init__(self, services: ServiceProvider):
|
||||
self._services = services
|
||||
self._graphql_app = None
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if self._graphql_app is None:
|
||||
schema = self._services.get_service(Schema)
|
||||
if not schema or not schema.schema:
|
||||
raise RuntimeError("GraphQL Schema not available yet")
|
||||
|
||||
self._graphql_app = GraphQL(
|
||||
schema.schema,
|
||||
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
||||
)
|
||||
|
||||
await self._graphql_app(scope, receive, send)
|
||||
@@ -10,6 +10,7 @@ from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
|
||||
from cpl.graphql._endpoints.graphql import graphql_endpoint
|
||||
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
|
||||
from cpl.graphql._endpoints.playground import playground_endpoint
|
||||
from cpl.graphql.graphql_module import GraphQLModule
|
||||
from cpl.graphql.service.schema import Schema
|
||||
@@ -43,6 +44,12 @@ class GraphQLApp(WebApp):
|
||||
schema = self._services.get_service(Schema)
|
||||
if schema is None:
|
||||
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
|
||||
#
|
||||
# graphql_ws_app = GraphQL(
|
||||
# schema,
|
||||
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
||||
# )
|
||||
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
|
||||
return schema
|
||||
|
||||
def with_graphiql(
|
||||
|
||||
@@ -1,109 +1,88 @@
|
||||
import inspect
|
||||
import types
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
|
||||
from typing import Any, Type, Optional, Self
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.dependency import ServiceProvider, get_provider, inject
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency import get_provider, inject
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.subscription_field import SubscriptionField
|
||||
from cpl.graphql.typing import Selector
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class Subscription(ABC, StrawberryProtocol):
|
||||
class Subscription(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
ABC.__init__(self)
|
||||
self._provider = provider
|
||||
self._fields: Dict[str, SubscriptionField] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> dict[str, SubscriptionField]:
|
||||
return self._fields
|
||||
|
||||
@property
|
||||
def fields_count(self) -> int:
|
||||
return len(self._fields)
|
||||
|
||||
def get_fields(self) -> dict[str, SubscriptionField]:
|
||||
return self._fields
|
||||
def __init__(self, bus: EventBusABC):
|
||||
QueryABC.__init__(self)
|
||||
self._bus = bus
|
||||
|
||||
def subscription_field(
|
||||
self,
|
||||
name: str,
|
||||
type_: Type,
|
||||
resolver: Callable[..., AsyncGenerator],
|
||||
selector: Selector = None,
|
||||
t: Type,
|
||||
selector: Optional[Selector] = None,
|
||||
channel: Optional[str] = None,
|
||||
) -> SubscriptionField:
|
||||
f = SubscriptionField(name, type_, resolver, selector)
|
||||
self._fields[name] = f
|
||||
return f
|
||||
field = SubscriptionField(name, t, selector, channel)
|
||||
self._fields[name] = field
|
||||
return field
|
||||
|
||||
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField:
|
||||
sub = self._provider.get_service(sub_cls)
|
||||
def with_subscription(self, sub_cls: Type[Self]) -> Self:
|
||||
sub = get_provider().get_service(sub_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider")
|
||||
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider")
|
||||
|
||||
async def _resolver(root, info):
|
||||
return sub
|
||||
for sub_name, sub_field in sub.get_fields().items():
|
||||
self._fields[sub_name] = sub_field
|
||||
|
||||
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver)
|
||||
return self._fields[name]
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _type_to_strawberry(t: Type) -> Type:
|
||||
_t = get_provider().get_service(t)
|
||||
if isinstance(_t, StrawberryProtocol):
|
||||
return _t.to_strawberry()
|
||||
return t
|
||||
def _field_to_strawberry(self, f: SubscriptionField) -> Any:
|
||||
try:
|
||||
if isinstance(f, SubscriptionField):
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: SubscriptionField) -> Callable:
|
||||
async def _resolver(root, info):
|
||||
async for event in f.resolver(root, info):
|
||||
if not f.selector or f.selector(event, info):
|
||||
def make_resolver(field: SubscriptionField):
|
||||
async def resolver(root=None, info=None):
|
||||
if not field.public:
|
||||
user = get_user()
|
||||
if not user:
|
||||
raise graphql_error(Unauthorized(f"{field.name}: Authentication required"))
|
||||
|
||||
if field.require_any_permission:
|
||||
ok = any([await user.has_permission(p) for p in field.require_any_permission])
|
||||
if not ok:
|
||||
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
|
||||
|
||||
if field.require_any:
|
||||
perms, resolvers = field.require_any
|
||||
ok = any([await user.has_permission(p) for p in perms])
|
||||
if not ok:
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
results = [
|
||||
r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx)
|
||||
for r in resolvers
|
||||
]
|
||||
if not any(results):
|
||||
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
|
||||
|
||||
async for event in self._bus.subscribe(field.channel):
|
||||
if field.selector is None or field.selector(event, info):
|
||||
yield event
|
||||
|
||||
return _resolver
|
||||
return resolver
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
return strawberry.subscription(resolver=make_resolver(f))
|
||||
|
||||
gql_cls = type(cls.__name__, (), {})
|
||||
TypeCollector.set(cls, gql_cls)
|
||||
async def wrapper_resolver(root=None, info=None):
|
||||
yield None
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
t = f.type
|
||||
if isinstance(t, types.GenericAlias):
|
||||
t = t.__args__[0]
|
||||
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
|
||||
t = self._type_to_strawberry(t)
|
||||
|
||||
annotations[name] = Optional[t]
|
||||
|
||||
try:
|
||||
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
|
||||
return strawberry.subscription(resolver=wrapper_resolver)
|
||||
|
||||
except StrawberryException as e:
|
||||
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
|
||||
|
||||
gql_cls.__annotations__ = annotations
|
||||
for k, v in namespace.items():
|
||||
setattr(gql_cls, k, v)
|
||||
|
||||
try:
|
||||
gql_type = strawberry.type(gql_cls)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
|
||||
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
@@ -1,11 +1,25 @@
|
||||
from typing import Type, Callable
|
||||
from typing import Type, Callable, Optional
|
||||
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import Selector
|
||||
|
||||
|
||||
class SubscriptionField:
|
||||
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None):
|
||||
self.name = name
|
||||
self.type = type_
|
||||
self.resolver = resolver
|
||||
class SubscriptionField(Field):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
t: Type,
|
||||
selector: Optional[Selector] = None,
|
||||
channel: Optional[str] = None,
|
||||
):
|
||||
super().__init__(name, t)
|
||||
self.selector = selector
|
||||
self.channel = channel or name
|
||||
|
||||
def with_selector(self, selector: Selector) -> "SubscriptionField":
|
||||
self.selector = selector
|
||||
return self
|
||||
|
||||
def with_channel(self, channel: str) -> "SubscriptionField":
|
||||
self.channel = channel
|
||||
return self
|
||||
|
||||
@@ -65,11 +65,12 @@ class Schema:
|
||||
|
||||
query = self.query
|
||||
mutation = self.mutation
|
||||
subscription = self.subscription
|
||||
|
||||
self._schema = strawberry.Schema(
|
||||
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None,
|
||||
subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None,
|
||||
types=self._get_types(),
|
||||
)
|
||||
return self._schema
|
||||
|
||||
Reference in New Issue
Block a user