diff --git a/example/api/src/main.py b/example/api/src/main.py index 7dc53652..4c71bbc9 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -100,7 +100,7 @@ def main(): schema.mutation.with_mutation("post", PostMutation).with_public() - schema.subscription.with_subscription("post", PostSubscription) + schema.subscription.with_subscription(PostSubscription) app.with_auth_root_queries(True) app.with_auth_root_mutations(True) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index eab7525f..5fe134f6 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -86,15 +86,10 @@ class PostSubscription(Subscription): Subscription.__init__(self) self._bus = bus - async def post_changed(): - async for event in await self._bus.subscribe("postChange"): - print("Event:", event, type(event)) - yield event - def selector(event: Post, info) -> bool: - return True + return event.id == 101 - self.subscription_field("postChange", PostGraphType, post_changed, selector) + self.subscription_field("postChange", PostGraphType, selector).with_public() class PostMutation(Mutation): diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index f994444e..f94694f9 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -169,6 +169,30 @@ class WebApp(WebAppABC): return self + def with_websocket( + self, + path: str, + fn: TEndpoint, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self._check_for_app() + assert path is not None, "path must not be None" + assert fn is not None, "fn must not be None" + + Router.websocket(path, registry=self._routes)(fn) + + if authentication: + Router.authenticate()(fn) + + if roles or permissions or policies: + Router.authorize(roles, permissions, policies, match)(fn) + + return self + def with_middleware(self, middleware: PartialMiddleware) -> Self: self._check_for_app() diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 4f3ae5a4..d5e73721 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -5,6 +5,7 @@ from uuid import uuid4 from starlette.requests import Request from starlette.types import Scope, Receive, Send +from starlette.websockets import WebSocket from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger @@ -33,7 +34,7 @@ class RequestMiddleware(ASGIMiddleware): self._ctx_token = None async def __call__(self, scope: Scope, receive: Receive, send: Send): - request = Request(scope, receive, send) + request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send) await self.set_request_data(request) try: diff --git a/src/cpl-api/cpl/api/model/websocket_route.py b/src/cpl-api/cpl/api/model/websocket_route.py new file mode 100644 index 00000000..3c09ca3f --- /dev/null +++ b/src/cpl-api/cpl/api/model/websocket_route.py @@ -0,0 +1,31 @@ +from typing import Callable + +import starlette.routing + + +class WebSocketRoute: + + def __init__(self, path: str, fn: Callable, **kwargs): + self._path = path + self._fn = fn + + self._kwargs = kwargs + + @property + def name(self) -> str: + return self._fn.__name__ + + @property + def fn(self) -> Callable: + return self._fn + + @property + def path(self) -> str: + return self._path + + @property + def kwargs(self) -> dict: + return self._kwargs + + def to_starlette(self, *args) -> starlette.routing.WebSocketRoute: + return starlette.routing.WebSocketRoute(self._path, self._fn) diff --git a/src/cpl-api/cpl/api/registry/route.py b/src/cpl-api/cpl/api/registry/route.py index e030007b..83ce7862 100644 --- a/src/cpl-api/cpl/api/registry/route.py +++ b/src/cpl-api/cpl/api/registry/route.py @@ -1,32 +1,35 @@ -from typing import Optional +from typing import Optional, Union from cpl.api.model.api_route import ApiRoute +from cpl.api.model.websocket_route import WebSocketRoute from cpl.core.abc.registry_abc import RegistryABC +TRoute = Union[ApiRoute, WebSocketRoute] + class RouteRegistry(RegistryABC): def __init__(self): RegistryABC.__init__(self) - def extend(self, items: list[ApiRoute]): + def extend(self, items: list[TRoute]): for policy in items: self.add(policy) - def add(self, item: ApiRoute): - assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" + def add(self, item: TRoute): + assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute" if item.path in self._items: raise ValueError(f"ApiRoute {item.path} is already registered") self._items[item.path] = item - def set(self, item: ApiRoute): + def set(self, item: TRoute): assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" self._items[item.path] = item - def get(self, key: str) -> Optional[ApiRoute]: + def get(self, key: str) -> Optional[TRoute]: return self._items.get(key) - def all(self) -> list[ApiRoute]: + def all(self) -> list[TRoute]: return list(self._items.values()) diff --git a/src/cpl-api/cpl/api/router.py b/src/cpl-api/cpl/api/router.py index 27dfd5ab..55369c38 100644 --- a/src/cpl-api/cpl/api/router.py +++ b/src/cpl-api/cpl/api/router.py @@ -91,6 +91,22 @@ class Router: return inner + @classmethod + def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs): + from cpl.api.model.websocket_route import WebSocketRoute + + if not registry: + routes = get_provider().get_service(RouteRegistry) + else: + routes = registry + + def inner(fn): + routes.add(WebSocketRoute(path, fn, **kwargs)) + setattr(fn, "_route_path", path) + return fn + + return inner + @classmethod def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): from cpl.api.model.api_route import ApiRoute diff --git a/src/cpl-core/cpl/core/log/structured_logger.py b/src/cpl-core/cpl/core/log/structured_logger.py index 2d1b9eca..e8e45849 100644 --- a/src/cpl-core/cpl/core/log/structured_logger.py +++ b/src/cpl-core/cpl/core/log/structured_logger.py @@ -68,7 +68,7 @@ class StructuredLogger(Logger): message["request"] = { "url": str(request.url), - "method": request.method, + "method": request.method if request.scope == "http" else "websocket", "scope": self._scope_to_json(request), } if isinstance(request, Request) and request.scope == "http": diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py index 70a81ad3..a369fd64 100644 --- a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py @@ -9,7 +9,10 @@ async def graphiql_endpoint(request): GraphiQL - +
@@ -21,13 +24,39 @@ async def graphiql_endpoint(request): + + +