WIP: dev into master #184
@@ -1,8 +1,10 @@
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
|
||||
from api.src.queries.hello import UserGraphType#, UserFilter, UserSort, UserGraphType
|
||||
from api.src.queries.user import UserFilter, UserSort
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
from cpl.graphql.event_bus.memory import InMemoryEventBus
|
||||
from queries.cities import CityGraphType, CityFilter, CitySort
|
||||
from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType
|
||||
from queries.user import UserFilter, UserSort
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.application.application_builder import ApplicationBuilder
|
||||
from cpl.auth.schema import User, Role
|
||||
@@ -17,7 +19,7 @@ from cpl.graphql.graphql_module import GraphQLModule
|
||||
from model.author_dao import AuthorDao
|
||||
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
|
||||
from model.post_dao import PostDao
|
||||
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation
|
||||
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription
|
||||
from permissions import PostPermissions
|
||||
from queries.hello import HelloQuery
|
||||
from scoped_service import ScopedService
|
||||
@@ -41,6 +43,7 @@ def main():
|
||||
.add_module(GraphQLModule)
|
||||
.add_module(GraphQLAuthModule)
|
||||
.add_scoped(ScopedService)
|
||||
.add_singleton(EventBusABC, InMemoryEventBus)
|
||||
.add_cache(User)
|
||||
.add_cache(Role)
|
||||
.add_transient(CityGraphType)
|
||||
@@ -66,6 +69,7 @@ def main():
|
||||
.add_transient(PostFilter)
|
||||
.add_transient(PostSort)
|
||||
.add_transient(PostMutation)
|
||||
.add_transient(PostSubscription)
|
||||
)
|
||||
|
||||
app = builder.build()
|
||||
@@ -96,6 +100,8 @@ def main():
|
||||
|
||||
schema.mutation.with_mutation("post", PostMutation).with_public()
|
||||
|
||||
schema.subscription.with_subscription("post", PostSubscription)
|
||||
|
||||
app.with_auth_root_queries(True)
|
||||
app.with_auth_root_mutations(True)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
@@ -5,6 +6,7 @@ from cpl.graphql.schema.input import Input
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from cpl.graphql.schema.subscription import Subscription
|
||||
from model.author_dao import AuthorDao
|
||||
from model.author_query import AuthorGraphType, AuthorFilter
|
||||
from model.post import Post
|
||||
@@ -19,6 +21,7 @@ class PostFilter(DbModelFilter[Post]):
|
||||
self.string_field("title")
|
||||
self.string_field("content")
|
||||
|
||||
|
||||
class PostSort(Sort[Post]):
|
||||
def __init__(self):
|
||||
Sort.__init__(self)
|
||||
@@ -26,6 +29,7 @@ class PostSort(Sort[Post]):
|
||||
self.field("title", SortOrder)
|
||||
self.field("content", SortOrder)
|
||||
|
||||
|
||||
class PostGraphType(DbModelGraphType[Post]):
|
||||
|
||||
def __init__(self, authors: AuthorDao):
|
||||
@@ -42,7 +46,7 @@ class PostGraphType(DbModelGraphType[Post]):
|
||||
def r_name(ctx: QueryContext):
|
||||
return ctx.user.username == "admin"
|
||||
|
||||
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name]))
|
||||
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name]))
|
||||
self.string_field(
|
||||
"title",
|
||||
resolver=lambda root: root.title,
|
||||
@@ -64,6 +68,7 @@ class PostCreateInput(Input[Post]):
|
||||
self.string_field("content").with_required()
|
||||
self.int_field("author_id").with_required()
|
||||
|
||||
|
||||
class PostUpdateInput(Input[Post]):
|
||||
title: str
|
||||
content: str
|
||||
@@ -75,13 +80,31 @@ class PostUpdateInput(Input[Post]):
|
||||
self.string_field("title").with_required(False)
|
||||
self.string_field("content").with_required(False)
|
||||
|
||||
|
||||
class PostSubscription(Subscription):
|
||||
def __init__(self, bus: EventBusABC):
|
||||
Subscription.__init__(self)
|
||||
self._bus = bus
|
||||
|
||||
async def post_changed():
|
||||
async for event in await self._bus.subscribe("postChange"):
|
||||
print("Event:", event, type(event))
|
||||
yield event
|
||||
|
||||
def selector(event: Post, info) -> bool:
|
||||
return True
|
||||
|
||||
self.subscription_field("postChange", PostGraphType, post_changed, selector)
|
||||
|
||||
|
||||
class PostMutation(Mutation):
|
||||
|
||||
def __init__(self, posts: PostDao, authors: AuthorDao):
|
||||
def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC):
|
||||
Mutation.__init__(self)
|
||||
|
||||
self._posts = posts
|
||||
self._authors = authors
|
||||
self._bus = bus
|
||||
|
||||
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
|
||||
"input",
|
||||
@@ -112,6 +135,7 @@ class PostMutation(Mutation):
|
||||
post.content = input.content if input.content is not None else post.content
|
||||
|
||||
await self._posts.update(post)
|
||||
await self._bus.publish("postChange", post)
|
||||
return True
|
||||
|
||||
async def delete_post(self, id: int) -> bool:
|
||||
@@ -127,4 +151,3 @@ class PostMutation(Mutation):
|
||||
return False
|
||||
await self._posts.restore(post)
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City
|
||||
from api.src.queries.user import User, UserFilter, UserSort, UserGraphType
|
||||
from queries.cities import CityFilter, CitySort, CityGraphType, City
|
||||
from queries.user import User, UserFilter, UserSort, UserGraphType
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.auth.schema import UserDao, User
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
@@ -38,6 +38,7 @@ cities = [City(i, f"City {i}") for i in range(1, 101)]
|
||||
# resolver=lambda root: root.username,
|
||||
# )
|
||||
|
||||
|
||||
class HelloQuery(Query):
|
||||
def __init__(self):
|
||||
Query.__init__(self)
|
||||
|
||||
@@ -21,9 +21,7 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(
|
||||
self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao
|
||||
):
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
|
||||
@@ -119,6 +119,13 @@ class MySQLPool:
|
||||
try:
|
||||
async with await con.cursor(dictionary=True) as cursor:
|
||||
res = await self._exec_sql(cursor, query, args, multi)
|
||||
return list(res)
|
||||
decoded_res = []
|
||||
for row in res:
|
||||
decoded_row = {
|
||||
k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items()
|
||||
}
|
||||
decoded_res.append(decoded_row)
|
||||
|
||||
return decoded_res
|
||||
finally:
|
||||
await con.close()
|
||||
|
||||
10
src/cpl-dependency/cpl/dependency/event_bus.py
Normal file
10
src/cpl-dependency/cpl/dependency/event_bus.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
|
||||
class EventBusABC(ABC):
|
||||
@abstractmethod
|
||||
async def publish(self, channel: str, event: Any) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ...
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import Protocol, Type, runtime_checkable
|
||||
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.subscription_field import SubscriptionField
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class StrawberryProtocol(Protocol):
|
||||
def to_strawberry(self) -> Type: ...
|
||||
def get_fields(self) -> dict[str, Field]: ...
|
||||
|
||||
def get_fields(self) -> dict[str, Field | SubscriptionField]: ...
|
||||
|
||||
@@ -114,4 +114,6 @@ class GraphQLApp(WebApp):
|
||||
if self._with_graphiql:
|
||||
self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql")
|
||||
if self._with_playground:
|
||||
self._logger.warning(f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground")
|
||||
self._logger.warning(
|
||||
f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground"
|
||||
)
|
||||
|
||||
@@ -8,13 +8,12 @@ from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
class ApiKeyMutation(Mutation):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
logger: APILogger,
|
||||
api_key_dao: ApiKeyDao,
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
permission_dao: ApiKeyPermissionDao,
|
||||
keycloak_admin: KeycloakAdmin,
|
||||
self,
|
||||
logger: APILogger,
|
||||
api_key_dao: ApiKeyDao,
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
permission_dao: ApiKeyPermissionDao,
|
||||
keycloak_admin: KeycloakAdmin,
|
||||
):
|
||||
Mutation.__init__(self)
|
||||
self._logger = logger
|
||||
@@ -61,9 +60,7 @@ class ApiKeyMutation(Mutation):
|
||||
api_key = ApiKey.new(obj.identifier)
|
||||
await self._api_key_dao.create(api_key)
|
||||
api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}])
|
||||
await self._api_key_permission_dao.create_many(
|
||||
[ApiKeyPermission(0, api_key.id, x) for x in obj.permissions]
|
||||
)
|
||||
await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions])
|
||||
return api_key
|
||||
|
||||
async def resolve_update(self, input: ApiKeyUpdateInput):
|
||||
|
||||
0
src/cpl-graphql/cpl/graphql/event_bus/__init__.py
Normal file
0
src/cpl-graphql/cpl/graphql/event_bus/__init__.py
Normal file
27
src/cpl-graphql/cpl/graphql/event_bus/memory.py
Normal file
27
src/cpl-graphql/cpl/graphql/event_bus/memory.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
|
||||
|
||||
class InMemoryEventBus(EventBusABC):
|
||||
def __init__(self):
|
||||
self._subscribers: dict[str, list[asyncio.Queue]] = {}
|
||||
|
||||
async def publish(self, channel: str, event: Any) -> None:
|
||||
queues = self._subscribers.get(channel, [])
|
||||
for q in queues.copy():
|
||||
await q.put(event)
|
||||
|
||||
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]:
|
||||
q = asyncio.Queue()
|
||||
if channel not in self._subscribers:
|
||||
self._subscribers[channel] = []
|
||||
self._subscribers[channel].append(q)
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await q.get()
|
||||
yield item
|
||||
finally:
|
||||
self._subscribers[channel].remove(q)
|
||||
@@ -8,13 +8,14 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.schema.root_subscription import RootSubscription
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLModule(Module):
|
||||
dependencies = [ApiModule]
|
||||
singleton = [Schema, RootQuery, RootMutation]
|
||||
singleton = [Schema, RootQuery, RootMutation, RootSubscription]
|
||||
scoped = [GraphQLService]
|
||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class CollectionGraphTypeFactory:
|
||||
@classmethod
|
||||
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
|
||||
@@ -23,11 +24,7 @@ class CollectionGraphTypeFactory:
|
||||
|
||||
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||
|
||||
gql_cls = type(
|
||||
type_name,
|
||||
(),
|
||||
{}
|
||||
)
|
||||
gql_cls = type(type_name, (), {})
|
||||
|
||||
TypeCollector.set(type_name, gql_cls)
|
||||
|
||||
@@ -45,8 +42,6 @@ class CollectionGraphTypeFactory:
|
||||
return gql_type
|
||||
|
||||
|
||||
|
||||
|
||||
class Collection:
|
||||
def __init__(self, nodes: list[T], total_count: int, count: int):
|
||||
self._nodes = nodes
|
||||
|
||||
@@ -26,6 +26,7 @@ class DbModelGraphType(GraphType[T], Generic[T]):
|
||||
|
||||
if Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
from cpl.graphql.auth.user.user_graph_type import UserGraphType
|
||||
|
||||
self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public)
|
||||
|
||||
self.string_field("created", lambda root: root.created).with_public(public)
|
||||
|
||||
@@ -16,6 +16,7 @@ class DbModelFilter(Filter[T], Generic[T]):
|
||||
self.field("deleted", BoolFilter).with_public(public)
|
||||
if Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
from cpl.graphql.auth.user.user_filter import UserFilter
|
||||
|
||||
self.field("editor", lambda: UserFilter).with_public(public)
|
||||
|
||||
self.field("created", DateFilter).with_public(public)
|
||||
|
||||
@@ -28,14 +28,14 @@ class Mutation(QueryABC):
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_assignments(
|
||||
foreign_objs: list[int],
|
||||
resolved_obj: T,
|
||||
reference_key_own: Union[str, property],
|
||||
reference_key_foreign: Union[str, property],
|
||||
source_dao: DataAccessObjectABC[T],
|
||||
join_dao: DataAccessObjectABC[T],
|
||||
join_type: Type[DbJoinModelABC],
|
||||
foreign_dao: DataAccessObjectABC[T],
|
||||
foreign_objs: list[int],
|
||||
resolved_obj: T,
|
||||
reference_key_own: Union[str, property],
|
||||
reference_key_foreign: Union[str, property],
|
||||
source_dao: DataAccessObjectABC[T],
|
||||
join_dao: DataAccessObjectABC[T],
|
||||
join_type: Type[DbJoinModelABC],
|
||||
foreign_dao: DataAccessObjectABC[T],
|
||||
):
|
||||
if foreign_objs is None:
|
||||
return
|
||||
@@ -44,9 +44,7 @@ class Mutation(QueryABC):
|
||||
if isinstance(reference_key_foreign, property):
|
||||
reference_key_foreign_attr = reference_key_foreign.fget.__name__
|
||||
|
||||
foreign_list = await join_dao.find_by(
|
||||
[{reference_key_own: resolved_obj.id}, {"deleted": False}]
|
||||
)
|
||||
foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}])
|
||||
|
||||
to_delete = (
|
||||
foreign_list
|
||||
@@ -61,9 +59,7 @@ class Mutation(QueryABC):
|
||||
foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list]
|
||||
deleted_foreign_ids = [
|
||||
getattr(x, reference_key_foreign_attr)
|
||||
for x in await join_dao.find_by(
|
||||
[{reference_key_own: resolved_obj.id}, {"deleted": True}]
|
||||
)
|
||||
for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}])
|
||||
]
|
||||
|
||||
to_create = [
|
||||
@@ -94,6 +90,4 @@ class Mutation(QueryABC):
|
||||
foreign_changes = [*to_delete, *to_create, *to_restore]
|
||||
if len(foreign_changes) > 0:
|
||||
await source_dao.touch(resolved_obj)
|
||||
await foreign_dao.touch_many_by_id(
|
||||
[getattr(x, reference_key_foreign_attr) for x in foreign_changes]
|
||||
)
|
||||
await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes])
|
||||
|
||||
6
src/cpl-graphql/cpl/graphql/schema/root_subscription.py
Normal file
6
src/cpl-graphql/cpl/graphql/schema/root_subscription.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.graphql.schema.subscription import Subscription
|
||||
|
||||
|
||||
class RootSubscription(Subscription):
|
||||
def __init__(self):
|
||||
Subscription.__init__(self)
|
||||
109
src/cpl-graphql/cpl/graphql/schema/subscription.py
Normal file
109
src/cpl-graphql/cpl/graphql/schema/subscription.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import inspect
|
||||
import types
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.dependency import ServiceProvider, get_provider, inject
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.subscription_field import SubscriptionField
|
||||
from cpl.graphql.typing import Selector
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class Subscription(ABC, StrawberryProtocol):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
ABC.__init__(self)
|
||||
self._provider = provider
|
||||
self._fields: Dict[str, SubscriptionField] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> dict[str, SubscriptionField]:
|
||||
return self._fields
|
||||
|
||||
@property
|
||||
def fields_count(self) -> int:
|
||||
return len(self._fields)
|
||||
|
||||
def get_fields(self) -> dict[str, SubscriptionField]:
|
||||
return self._fields
|
||||
|
||||
def subscription_field(
|
||||
self,
|
||||
name: str,
|
||||
type_: Type,
|
||||
resolver: Callable[..., AsyncGenerator],
|
||||
selector: Selector = None,
|
||||
) -> SubscriptionField:
|
||||
f = SubscriptionField(name, type_, resolver, selector)
|
||||
self._fields[name] = f
|
||||
return f
|
||||
|
||||
def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField:
|
||||
sub = self._provider.get_service(sub_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider")
|
||||
|
||||
async def _resolver(root, info):
|
||||
return sub
|
||||
|
||||
self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver)
|
||||
return self._fields[name]
|
||||
|
||||
@staticmethod
|
||||
def _type_to_strawberry(t: Type) -> Type:
|
||||
_t = get_provider().get_service(t)
|
||||
if isinstance(_t, StrawberryProtocol):
|
||||
return _t.to_strawberry()
|
||||
return t
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: SubscriptionField) -> Callable:
|
||||
async def _resolver(root, info):
|
||||
async for event in f.resolver(root, info):
|
||||
if not f.selector or f.selector(event, info):
|
||||
yield event
|
||||
|
||||
return _resolver
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
gql_cls = type(cls.__name__, (), {})
|
||||
TypeCollector.set(cls, gql_cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
t = f.type
|
||||
if isinstance(t, types.GenericAlias):
|
||||
t = t.__args__[0]
|
||||
elif isinstance(t, type) and issubclass(t, StrawberryProtocol):
|
||||
t = self._type_to_strawberry(t)
|
||||
|
||||
annotations[name] = Optional[t]
|
||||
|
||||
try:
|
||||
namespace[name] = strawberry.subscription(resolver=self._build_resolver(f))
|
||||
|
||||
except StrawberryException as e:
|
||||
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
|
||||
|
||||
gql_cls.__annotations__ = annotations
|
||||
for k, v in namespace.items():
|
||||
setattr(gql_cls, k, v)
|
||||
|
||||
try:
|
||||
gql_type = strawberry.type(gql_cls)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
|
||||
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
11
src/cpl-graphql/cpl/graphql/schema/subscription_field.py
Normal file
11
src/cpl-graphql/cpl/graphql/schema/subscription_field.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Type, Callable
|
||||
|
||||
from cpl.graphql.typing import Selector
|
||||
|
||||
|
||||
class SubscriptionField:
|
||||
def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None):
|
||||
self.name = name
|
||||
self.type = type_
|
||||
self.resolver = resolver
|
||||
self.selector = selector
|
||||
@@ -8,6 +8,7 @@ from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.schema.root_subscription import RootSubscription
|
||||
|
||||
|
||||
class Schema:
|
||||
@@ -38,6 +39,13 @@ class Schema:
|
||||
raise ValueError("RootMutation not registered in service provider")
|
||||
return mutation
|
||||
|
||||
@property
|
||||
def subscription(self) -> RootSubscription:
|
||||
subscription = self._provider.get_service(RootSubscription)
|
||||
if not subscription:
|
||||
raise ValueError("RootSubscription not registered in service provider")
|
||||
return subscription
|
||||
|
||||
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||
self._types[t.__name__] = t
|
||||
return self
|
||||
@@ -61,7 +69,7 @@ class Schema:
|
||||
self._schema = strawberry.Schema(
|
||||
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||
subscription=None,
|
||||
subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None,
|
||||
types=self._get_types(),
|
||||
)
|
||||
return self._schema
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import Type, Callable, List, Tuple, Awaitable
|
||||
from typing import Type, Callable, List, Tuple, Awaitable, Any
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
|
||||
TQuery = Type["Query"]
|
||||
Resolver = Callable
|
||||
Selector = Callable[[Any, strawberry.types.Info], bool]
|
||||
ScalarType = str | int | float | bool | object
|
||||
AttributeName = str | property
|
||||
TRequireAnyPermissions = List[Enum | Permissions] | None
|
||||
|
||||
Reference in New Issue
Block a user