WIP: dev into master #184
@@ -100,7 +100,7 @@ def main():
|
|||||||
|
|
||||||
schema.mutation.with_mutation("post", PostMutation).with_public()
|
schema.mutation.with_mutation("post", PostMutation).with_public()
|
||||||
|
|
||||||
schema.subscription.with_subscription("post", PostSubscription)
|
schema.subscription.with_subscription(PostSubscription)
|
||||||
|
|
||||||
app.with_auth_root_queries(True)
|
app.with_auth_root_queries(True)
|
||||||
app.with_auth_root_mutations(True)
|
app.with_auth_root_mutations(True)
|
||||||
|
|||||||
@@ -86,15 +86,10 @@ class PostSubscription(Subscription):
|
|||||||
Subscription.__init__(self)
|
Subscription.__init__(self)
|
||||||
self._bus = bus
|
self._bus = bus
|
||||||
|
|
||||||
async def post_changed():
|
|
||||||
async for event in await self._bus.subscribe("postChange"):
|
|
||||||
print("Event:", event, type(event))
|
|
||||||
yield event
|
|
||||||
|
|
||||||
def selector(event: Post, info) -> bool:
|
def selector(event: Post, info) -> bool:
|
||||||
return True
|
return event.id == 101
|
||||||
|
|
||||||
self.subscription_field("postChange", PostGraphType, post_changed, selector)
|
self.subscription_field("postChange", PostGraphType, selector).with_public()
|
||||||
|
|
||||||
|
|
||||||
class PostMutation(Mutation):
|
class PostMutation(Mutation):
|
||||||
|
|||||||
@@ -169,6 +169,30 @@ class WebApp(WebAppABC):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_websocket(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
fn: TEndpoint,
|
||||||
|
authentication: bool = False,
|
||||||
|
roles: list[str | Enum] = None,
|
||||||
|
permissions: list[str | Enum] = None,
|
||||||
|
policies: list[str] = None,
|
||||||
|
match: ValidationMatch = None,
|
||||||
|
) -> Self:
|
||||||
|
self._check_for_app()
|
||||||
|
assert path is not None, "path must not be None"
|
||||||
|
assert fn is not None, "fn must not be None"
|
||||||
|
|
||||||
|
Router.websocket(path, registry=self._routes)(fn)
|
||||||
|
|
||||||
|
if authentication:
|
||||||
|
Router.authenticate()(fn)
|
||||||
|
|
||||||
|
if roles or permissions or policies:
|
||||||
|
Router.authorize(roles, permissions, policies, match)(fn)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
||||||
self._check_for_app()
|
self._check_for_app()
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.types import Scope, Receive, Send
|
from starlette.types import Scope, Receive, Send
|
||||||
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
@@ -33,7 +34,7 @@ class RequestMiddleware(ASGIMiddleware):
|
|||||||
self._ctx_token = None
|
self._ctx_token = None
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
request = Request(scope, receive, send)
|
request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send)
|
||||||
await self.set_request_data(request)
|
await self.set_request_data(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
31
src/cpl-api/cpl/api/model/websocket_route.py
Normal file
31
src/cpl-api/cpl/api/model/websocket_route.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import starlette.routing
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketRoute:
|
||||||
|
|
||||||
|
def __init__(self, path: str, fn: Callable, **kwargs):
|
||||||
|
self._path = path
|
||||||
|
self._fn = fn
|
||||||
|
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._fn.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fn(self) -> Callable:
|
||||||
|
return self._fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self) -> str:
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kwargs(self) -> dict:
|
||||||
|
return self._kwargs
|
||||||
|
|
||||||
|
def to_starlette(self, *args) -> starlette.routing.WebSocketRoute:
|
||||||
|
return starlette.routing.WebSocketRoute(self._path, self._fn)
|
||||||
@@ -1,32 +1,35 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from cpl.api.model.api_route import ApiRoute
|
from cpl.api.model.api_route import ApiRoute
|
||||||
|
from cpl.api.model.websocket_route import WebSocketRoute
|
||||||
from cpl.core.abc.registry_abc import RegistryABC
|
from cpl.core.abc.registry_abc import RegistryABC
|
||||||
|
|
||||||
|
TRoute = Union[ApiRoute, WebSocketRoute]
|
||||||
|
|
||||||
|
|
||||||
class RouteRegistry(RegistryABC):
|
class RouteRegistry(RegistryABC):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
RegistryABC.__init__(self)
|
RegistryABC.__init__(self)
|
||||||
|
|
||||||
def extend(self, items: list[ApiRoute]):
|
def extend(self, items: list[TRoute]):
|
||||||
for policy in items:
|
for policy in items:
|
||||||
self.add(policy)
|
self.add(policy)
|
||||||
|
|
||||||
def add(self, item: ApiRoute):
|
def add(self, item: TRoute):
|
||||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute"
|
||||||
|
|
||||||
if item.path in self._items:
|
if item.path in self._items:
|
||||||
raise ValueError(f"ApiRoute {item.path} is already registered")
|
raise ValueError(f"ApiRoute {item.path} is already registered")
|
||||||
|
|
||||||
self._items[item.path] = item
|
self._items[item.path] = item
|
||||||
|
|
||||||
def set(self, item: ApiRoute):
|
def set(self, item: TRoute):
|
||||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||||
self._items[item.path] = item
|
self._items[item.path] = item
|
||||||
|
|
||||||
def get(self, key: str) -> Optional[ApiRoute]:
|
def get(self, key: str) -> Optional[TRoute]:
|
||||||
return self._items.get(key)
|
return self._items.get(key)
|
||||||
|
|
||||||
def all(self) -> list[ApiRoute]:
|
def all(self) -> list[TRoute]:
|
||||||
return list(self._items.values())
|
return list(self._items.values())
|
||||||
|
|||||||
@@ -91,6 +91,22 @@ class Router:
|
|||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs):
|
||||||
|
from cpl.api.model.websocket_route import WebSocketRoute
|
||||||
|
|
||||||
|
if not registry:
|
||||||
|
routes = get_provider().get_service(RouteRegistry)
|
||||||
|
else:
|
||||||
|
routes = registry
|
||||||
|
|
||||||
|
def inner(fn):
|
||||||
|
routes.add(WebSocketRoute(path, fn, **kwargs))
|
||||||
|
setattr(fn, "_route_path", path)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||||
from cpl.api.model.api_route import ApiRoute
|
from cpl.api.model.api_route import ApiRoute
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class StructuredLogger(Logger):
|
|||||||
|
|
||||||
message["request"] = {
|
message["request"] = {
|
||||||
"url": str(request.url),
|
"url": str(request.url),
|
||||||
"method": request.method,
|
"method": request.method if request.scope == "http" else "websocket",
|
||||||
"scope": self._scope_to_json(request),
|
"scope": self._scope_to_json(request),
|
||||||
}
|
}
|
||||||
if isinstance(request, Request) and request.scope == "http":
|
if isinstance(request, Request) and request.scope == "http":
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ async def graphiql_endpoint(request):
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<title>GraphiQL</title>
|
<title>GraphiQL</title>
|
||||||
<link href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css" rel="stylesheet" />
|
<link
|
||||||
|
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
|
||||||
|
rel="stylesheet"
|
||||||
|
/>
|
||||||
</head>
|
</head>
|
||||||
<body style="margin:0;overflow:hidden;">
|
<body style="margin:0;overflow:hidden;">
|
||||||
<div id="graphiql" style="height:100vh;"></div>
|
<div id="graphiql" style="height:100vh;"></div>
|
||||||
@@ -21,13 +24,39 @@ async def graphiql_endpoint(request):
|
|||||||
<!-- GraphiQL -->
|
<!-- GraphiQL -->
|
||||||
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
|
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
|
||||||
|
|
||||||
|
<!-- GraphQL over WebSocket client -->
|
||||||
|
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
const graphQLFetcher = graphQLParams =>
|
const httpUrl = window.location.origin + '/api/graphql';
|
||||||
fetch('/api/graphql', {
|
const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
|
||||||
method: 'post',
|
window.location.host + '/api/graphql/ws';
|
||||||
|
|
||||||
|
// HTTP fetcher for queries & mutations
|
||||||
|
const httpFetcher = async (params) => {
|
||||||
|
const res = await fetch(httpUrl, {
|
||||||
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify(graphQLParams),
|
body: JSON.stringify(params),
|
||||||
}).then(response => response.json()).catch(() => response.text());
|
});
|
||||||
|
return res.json();
|
||||||
|
};
|
||||||
|
|
||||||
|
// WebSocket fetcher for subscriptions
|
||||||
|
const wsClient = graphqlWs.createClient({ url: wsUrl });
|
||||||
|
const wsFetcher = (params) => ({
|
||||||
|
subscribe: (sink) => ({
|
||||||
|
unsubscribe: wsClient.subscribe(params, sink),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
// smart fetcher wrapper (decides HTTP or WS)
|
||||||
|
const graphQLFetcher = (params) => {
|
||||||
|
if (params.query.trim().startsWith('subscription')) {
|
||||||
|
return wsFetcher(params);
|
||||||
|
}
|
||||||
|
return httpFetcher(params);
|
||||||
|
};
|
||||||
|
|
||||||
ReactDOM.render(
|
ReactDOM.render(
|
||||||
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),
|
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),
|
||||||
|
|||||||
27
src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py
Normal file
27
src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
from strawberry.asgi import GraphQL
|
||||||
|
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
|
||||||
|
|
||||||
|
from cpl.dependency import ServiceProvider
|
||||||
|
from cpl.graphql.service.schema import Schema
|
||||||
|
|
||||||
|
|
||||||
|
class LazyGraphQLApp:
|
||||||
|
|
||||||
|
def __init__(self, services: ServiceProvider):
|
||||||
|
self._services = services
|
||||||
|
self._graphql_app = None
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
if self._graphql_app is None:
|
||||||
|
schema = self._services.get_service(Schema)
|
||||||
|
if not schema or not schema.schema:
|
||||||
|
raise RuntimeError("GraphQL Schema not available yet")
|
||||||
|
|
||||||
|
self._graphql_app = GraphQL(
|
||||||
|
schema.schema,
|
||||||
|
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._graphql_app(scope, receive, send)
|
||||||
@@ -10,6 +10,7 @@ from cpl.dependency.service_provider import ServiceProvider
|
|||||||
from cpl.dependency.typing import Modules
|
from cpl.dependency.typing import Modules
|
||||||
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
|
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
|
||||||
from cpl.graphql._endpoints.graphql import graphql_endpoint
|
from cpl.graphql._endpoints.graphql import graphql_endpoint
|
||||||
|
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
|
||||||
from cpl.graphql._endpoints.playground import playground_endpoint
|
from cpl.graphql._endpoints.playground import playground_endpoint
|
||||||
from cpl.graphql.graphql_module import GraphQLModule
|
from cpl.graphql.graphql_module import GraphQLModule
|
||||||
from cpl.graphql.service.schema import Schema
|
from cpl.graphql.service.schema import Schema
|
||||||
@@ -43,6 +44,12 @@ class GraphQLApp(WebApp):
|
|||||||
schema = self._services.get_service(Schema)
|
schema = self._services.get_service(Schema)
|
||||||
if schema is None:
|
if schema is None:
|
||||||
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
|
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
|
||||||
|
#
|
||||||
|
# graphql_ws_app = GraphQL(
|
||||||
|
# schema,
|
||||||
|
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
||||||
|
# )
|
||||||
|
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def with_graphiql(
|
def with_graphiql(
|
||||||
|
|||||||
@@ -1,109 +1,88 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import types
|
from typing import Any, Type, Optional, Self
|
||||||
from abc import ABC
|
|
||||||
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
|
|
||||||
|
|
||||||
import strawberry
|
import strawberry
|
||||||
from strawberry.exceptions import StrawberryException
|
from strawberry.exceptions import StrawberryException
|
||||||
|
|
||||||
from cpl.dependency import ServiceProvider, get_provider, inject
|
from cpl.api import Unauthorized, Forbidden
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.core.ctx.user_context import get_user
|
||||||
|
from cpl.dependency import get_provider, inject
|
||||||
|
from cpl.dependency.event_bus import EventBusABC
|
||||||
|
from cpl.graphql.abc.query_abc import QueryABC
|
||||||
|
from cpl.graphql.error import graphql_error
|
||||||
|
from cpl.graphql.query_context import QueryContext
|
||||||
from cpl.graphql.schema.subscription_field import SubscriptionField
|
from cpl.graphql.schema.subscription_field import SubscriptionField
|
||||||
from cpl.graphql.typing import Selector
|
from cpl.graphql.typing import Selector
|
||||||
from cpl.graphql.utils.type_collector import TypeCollector
|
|
||||||
|
|
||||||
|
|
||||||
class Subscription(ABC, StrawberryProtocol):
|
class Subscription(QueryABC):
|
||||||
|
|
||||||
@inject
|
@inject
|
||||||
def __init__(self, provider: ServiceProvider):
|
def __init__(self, bus: EventBusABC):
|
||||||
ABC.__init__(self)
|
QueryABC.__init__(self)
|
||||||
self._provider = provider
|
self._bus = bus
|
||||||
self._fields: Dict[str, SubscriptionField] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def fields(self) -> dict[str, SubscriptionField]:
|
|
||||||
return self._fields
|
|
||||||
|
|
||||||
@property
|
|
||||||
def fields_count(self) -> int:
|
|
||||||
return len(self._fields)
|
|
||||||
|
|
||||||
def get_fields(self) -> dict[str, SubscriptionField]:
|
|
||||||
return self._fields
|
|
||||||
|
|
||||||
def subscription_field(
|
def subscription_field(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
type_: Type,
|
t: Type,
|
||||||
resolver: Callable[..., AsyncGenerator],
|
selector: Optional[Selector] = None,
|
||||||
selector: Selector = None,
|
channel: Optional[str] = None,
|
||||||
) -> SubscriptionField:
|
) -> SubscriptionField:
|
||||||
f = SubscriptionField(name, type_, resolver, selector)
|
field = SubscriptionField(name, t, selector, channel)
|
||||||
self._fields[name] = f
|
self._fields[name] = field
|
||||||
return f
|
return field
|
||||||
|
|
||||||
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField:
|
def with_subscription(self, sub_cls: Type[Self]) -> Self:
|
||||||
sub = self._provider.get_service(sub_cls)
|
sub = get_provider().get_service(sub_cls)
|
||||||
if not sub:
|
if not sub:
|
||||||
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider")
|
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider")
|
||||||
|
|
||||||
async def _resolver(root, info):
|
for sub_name, sub_field in sub.get_fields().items():
|
||||||
return sub
|
self._fields[sub_name] = sub_field
|
||||||
|
|
||||||
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver)
|
return self
|
||||||
return self._fields[name]
|
|
||||||
|
|
||||||
@staticmethod
|
def _field_to_strawberry(self, f: SubscriptionField) -> Any:
|
||||||
def _type_to_strawberry(t: Type) -> Type:
|
try:
|
||||||
_t = get_provider().get_service(t)
|
if isinstance(f, SubscriptionField):
|
||||||
if isinstance(_t, StrawberryProtocol):
|
|
||||||
return _t.to_strawberry()
|
|
||||||
return t
|
|
||||||
|
|
||||||
@staticmethod
|
def make_resolver(field: SubscriptionField):
|
||||||
def _build_resolver(f: SubscriptionField) -> Callable:
|
async def resolver(root=None, info=None):
|
||||||
async def _resolver(root, info):
|
if not field.public:
|
||||||
async for event in f.resolver(root, info):
|
user = get_user()
|
||||||
if not f.selector or f.selector(event, info):
|
if not user:
|
||||||
|
raise graphql_error(Unauthorized(f"{field.name}: Authentication required"))
|
||||||
|
|
||||||
|
if field.require_any_permission:
|
||||||
|
ok = any([await user.has_permission(p) for p in field.require_any_permission])
|
||||||
|
if not ok:
|
||||||
|
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
|
||||||
|
|
||||||
|
if field.require_any:
|
||||||
|
perms, resolvers = field.require_any
|
||||||
|
ok = any([await user.has_permission(p) for p in perms])
|
||||||
|
if not ok:
|
||||||
|
ctx = QueryContext([x.name for x in await user.permissions])
|
||||||
|
results = [
|
||||||
|
r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx)
|
||||||
|
for r in resolvers
|
||||||
|
]
|
||||||
|
if not any(results):
|
||||||
|
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
|
||||||
|
|
||||||
|
async for event in self._bus.subscribe(field.channel):
|
||||||
|
if field.selector is None or field.selector(event, info):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
return _resolver
|
return resolver
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
return strawberry.subscription(resolver=make_resolver(f))
|
||||||
cls = self.__class__
|
|
||||||
if TypeCollector.has(cls):
|
|
||||||
return TypeCollector.get(cls)
|
|
||||||
|
|
||||||
gql_cls = type(cls.__name__, (), {})
|
async def wrapper_resolver(root=None, info=None):
|
||||||
TypeCollector.set(cls, gql_cls)
|
yield None
|
||||||
|
|
||||||
annotations: dict[str, Any] = {}
|
return strawberry.subscription(resolver=wrapper_resolver)
|
||||||
namespace: dict[str, Any] = {}
|
|
||||||
|
|
||||||
for name, f in self._fields.items():
|
|
||||||
t = f.type
|
|
||||||
if isinstance(t, types.GenericAlias):
|
|
||||||
t = t.__args__[0]
|
|
||||||
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
|
|
||||||
t = self._type_to_strawberry(t)
|
|
||||||
|
|
||||||
annotations[name] = Optional[t]
|
|
||||||
|
|
||||||
try:
|
|
||||||
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
|
|
||||||
|
|
||||||
except StrawberryException as e:
|
except StrawberryException as e:
|
||||||
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
|
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
|
||||||
|
|
||||||
gql_cls.__annotations__ = annotations
|
|
||||||
for k, v in namespace.items():
|
|
||||||
setattr(gql_cls, k, v)
|
|
||||||
|
|
||||||
try:
|
|
||||||
gql_type = strawberry.type(gql_cls)
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
|
|
||||||
|
|
||||||
TypeCollector.set(cls, gql_type)
|
|
||||||
return gql_type
|
|
||||||
|
|||||||
@@ -1,11 +1,25 @@
|
|||||||
from typing import Type, Callable
|
from typing import Type, Callable, Optional
|
||||||
|
|
||||||
|
from cpl.graphql.schema.field import Field
|
||||||
from cpl.graphql.typing import Selector
|
from cpl.graphql.typing import Selector
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionField:
|
class SubscriptionField(Field):
|
||||||
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None):
|
def __init__(
|
||||||
self.name = name
|
self,
|
||||||
self.type = type_
|
name: str,
|
||||||
self.resolver = resolver
|
t: Type,
|
||||||
|
selector: Optional[Selector] = None,
|
||||||
|
channel: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, t)
|
||||||
self.selector = selector
|
self.selector = selector
|
||||||
|
self.channel = channel or name
|
||||||
|
|
||||||
|
def with_selector(self, selector: Selector) -> "SubscriptionField":
|
||||||
|
self.selector = selector
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_channel(self, channel: str) -> "SubscriptionField":
|
||||||
|
self.channel = channel
|
||||||
|
return self
|
||||||
|
|||||||
@@ -65,11 +65,12 @@ class Schema:
|
|||||||
|
|
||||||
query = self.query
|
query = self.query
|
||||||
mutation = self.mutation
|
mutation = self.mutation
|
||||||
|
subscription = self.subscription
|
||||||
|
|
||||||
self._schema = strawberry.Schema(
|
self._schema = strawberry.Schema(
|
||||||
query=query.to_strawberry() if query.fields_count > 0 else None,
|
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||||
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||||
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None,
|
subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None,
|
||||||
types=self._get_types(),
|
types=self._get_types(),
|
||||||
)
|
)
|
||||||
return self._schema
|
return self._schema
|
||||||
|
|||||||
Reference in New Issue
Block a user