[WIP] with authentication #181
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 6s
This commit is contained in:
@@ -85,7 +85,7 @@ def main():
|
|||||||
schema.query.string_field("ping", resolver=lambda: "pong")
|
schema.query.string_field("ping", resolver=lambda: "pong")
|
||||||
schema.query.with_query("hello", HelloQuery)
|
schema.query.with_query("hello", HelloQuery)
|
||||||
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
||||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True)
|
||||||
|
|
||||||
app.with_playground()
|
app.with_playground()
|
||||||
app.with_graphiql()
|
app.with_graphiql()
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ class AuthorGraphType(GraphType[Author]):
|
|||||||
self.int_field(
|
self.int_field(
|
||||||
"id",
|
"id",
|
||||||
resolver=lambda root: root.id,
|
resolver=lambda root: root.id,
|
||||||
)
|
).with_public(True)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"firstName",
|
"firstName",
|
||||||
resolver=lambda root: root.first_name,
|
resolver=lambda root: root.first_name,
|
||||||
)
|
).with_public(True)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"lastName",
|
"lastName",
|
||||||
resolver=lambda root: root.last_name,
|
resolver=lambda root: root.last_name,
|
||||||
)
|
).with_public(True)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class PostGraphType(GraphType[Post]):
|
|||||||
self.int_field(
|
self.int_field(
|
||||||
"id",
|
"id",
|
||||||
resolver=lambda root: root.id,
|
resolver=lambda root: root.id,
|
||||||
)
|
).with_public(True)
|
||||||
|
|
||||||
async def _a(root: Post):
|
async def _a(root: Post):
|
||||||
return await authors.get_by_id(root.author_id)
|
return await authors.get_by_id(root.author_id)
|
||||||
@@ -38,12 +38,12 @@ class PostGraphType(GraphType[Post]):
|
|||||||
"author",
|
"author",
|
||||||
AuthorGraphType,
|
AuthorGraphType,
|
||||||
resolver=_a#lambda root: root.author_id,
|
resolver=_a#lambda root: root.author_id,
|
||||||
)
|
).with_public(True)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"title",
|
"title",
|
||||||
resolver=lambda root: root.title,
|
resolver=lambda root: root.title,
|
||||||
)
|
).with_public(True)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"content",
|
"content",
|
||||||
resolver=lambda root: root.content,
|
resolver=lambda root: root.content,
|
||||||
)
|
).with_public(True)
|
||||||
|
|||||||
@@ -25,6 +25,22 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
|||||||
request = get_request()
|
request = get_request()
|
||||||
url = request.url.path
|
url = request.url.path
|
||||||
|
|
||||||
|
if url not in Router.get_auth_required_routes():
|
||||||
|
self._logger.trace(f"No authentication required for {url}")
|
||||||
|
return await self._app(scope, receive, send)
|
||||||
|
|
||||||
|
# ab hier Auth erzwingen
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
if not user or user.deleted:
|
||||||
|
self._logger.debug(f"Unauthorized access to {url}, user missing or deleted")
|
||||||
|
return await Unauthorized("Unauthorized").asgi_response(scope, receive, send)
|
||||||
|
|
||||||
|
return await self._call_next(scope, receive, send)
|
||||||
|
|
||||||
|
async def _old_call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
|
request = get_request()
|
||||||
|
url = request.url.path
|
||||||
|
|
||||||
if url not in Router.get_auth_required_routes():
|
if url not in Router.get_auth_required_routes():
|
||||||
self._logger.trace(f"No authentication required for {url}")
|
self._logger.trace(f"No authentication required for {url}")
|
||||||
return await self._app(scope, receive, send)
|
return await self._app(scope, receive, send)
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ from starlette.types import Scope, Receive, Send
|
|||||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.api.typing import TRequest
|
from cpl.api.typing import TRequest
|
||||||
|
from cpl.auth.keycloak.keycloak_client import KeycloakClient
|
||||||
|
from cpl.auth.schema import AuthUser
|
||||||
|
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||||
|
from cpl.core.ctx import set_user
|
||||||
from cpl.dependency.inject import inject
|
from cpl.dependency.inject import inject
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
|
|
||||||
@@ -17,12 +21,15 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
|||||||
|
|
||||||
class RequestMiddleware(ASGIMiddleware):
|
class RequestMiddleware(ASGIMiddleware):
|
||||||
|
|
||||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
|
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||||
ASGIMiddleware.__init__(self, app)
|
ASGIMiddleware.__init__(self, app)
|
||||||
|
|
||||||
self._provider = provider
|
self._provider = provider
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
|
||||||
|
self._keycloak = keycloak
|
||||||
|
self._user_dao = user_dao
|
||||||
|
|
||||||
self._ctx_token = None
|
self._ctx_token = None
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
@@ -30,6 +37,7 @@ class RequestMiddleware(ASGIMiddleware):
|
|||||||
await self.set_request_data(request)
|
await self.set_request_data(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
await self._try_set_user(request)
|
||||||
with self._provider.create_scope():
|
with self._provider.create_scope():
|
||||||
inject(await self._app(scope, receive, send))
|
inject(await self._app(scope, receive, send))
|
||||||
finally:
|
finally:
|
||||||
@@ -53,6 +61,36 @@ class RequestMiddleware(ASGIMiddleware):
|
|||||||
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||||
_request_context.reset(self._ctx_token)
|
_request_context.reset(self._ctx_token)
|
||||||
|
|
||||||
|
async def _try_set_user(self, request: Request):
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if not auth_header or not auth_header.startswith("Bearer "):
|
||||||
|
return
|
||||||
|
|
||||||
|
token = auth_header.split("Bearer ")[1]
|
||||||
|
try:
|
||||||
|
token_info = self._keycloak.introspect(token)
|
||||||
|
if not token_info.get("active", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
keycloak_id = self._keycloak.get_user_id(token)
|
||||||
|
if not keycloak_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
user = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||||
|
if not user:
|
||||||
|
user = AuthUser(0, keycloak_id)
|
||||||
|
uid = await self._user_dao.create(user)
|
||||||
|
user = await self._user_dao.get_by_id(uid)
|
||||||
|
|
||||||
|
if user.deleted:
|
||||||
|
return
|
||||||
|
|
||||||
|
request.state.user = user
|
||||||
|
set_user(user)
|
||||||
|
self._logger.trace(f"User {user.id} bound to request {request.state.request_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.debug(f"Silent user binding failed: {e}")
|
||||||
|
|
||||||
def get_request() -> Optional[TRequest]:
|
def get_request() -> Optional[TRequest]:
|
||||||
return _request_context.get()
|
return _request_context.get()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response, JSONResponse
|
from starlette.responses import Response, JSONResponse
|
||||||
|
|
||||||
from cpl.graphql.service.service import GraphQLService
|
from cpl.graphql.service.graphql import GraphQLService
|
||||||
|
|
||||||
|
|
||||||
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:
|
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:
|
||||||
|
|||||||
14
src/cpl-graphql/cpl/graphql/error.py
Normal file
14
src/cpl-graphql/cpl/graphql/error.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from graphql import GraphQLError
|
||||||
|
|
||||||
|
from cpl.api import APIError
|
||||||
|
|
||||||
|
|
||||||
|
def graphql_error(api_error: APIError) -> GraphQLError:
|
||||||
|
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
|
||||||
|
return GraphQLError(
|
||||||
|
message=api_error.error_message,
|
||||||
|
extensions={
|
||||||
|
"code": api_error.status_code,
|
||||||
|
},
|
||||||
|
original_error=api_error,
|
||||||
|
)
|
||||||
@@ -8,7 +8,7 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
|
|||||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||||
from cpl.graphql.schema.root_query import RootQuery
|
from cpl.graphql.schema.root_query import RootQuery
|
||||||
from cpl.graphql.service.schema import Schema
|
from cpl.graphql.service.schema import Schema
|
||||||
from cpl.graphql.service.service import GraphQLService
|
from cpl.graphql.service.graphql import GraphQLService
|
||||||
|
|
||||||
|
|
||||||
class GraphQLModule(Module):
|
class GraphQLModule(Module):
|
||||||
|
|||||||
90
src/cpl-graphql/cpl/graphql/query_context.py
Normal file
90
src/cpl-graphql/cpl/graphql/query_context.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from graphql import GraphQLResolveInfo
|
||||||
|
|
||||||
|
from cpl.auth.schema import AuthUser, Permission
|
||||||
|
from cpl.core.utils import get_value
|
||||||
|
|
||||||
|
|
||||||
|
class QueryContext:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data: Any,
|
||||||
|
user: Optional[AuthUser],
|
||||||
|
user_permissions: Optional[list[Enum | Permission]],
|
||||||
|
is_mutation: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
self._data = data
|
||||||
|
self._user = user
|
||||||
|
if user_permissions is None:
|
||||||
|
user_permissions = []
|
||||||
|
self._user_permissions: list[str] = [x.name for x in user_permissions]
|
||||||
|
|
||||||
|
self._resolve_info = None
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, GraphQLResolveInfo):
|
||||||
|
self._resolve_info = arg
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._filter = kwargs.get("filters", {})
|
||||||
|
self._sort = kwargs.get("sort", {})
|
||||||
|
self._skip = get_value(kwargs, "skip", int)
|
||||||
|
self._take = get_value(kwargs, "take", int)
|
||||||
|
|
||||||
|
self._input = kwargs.get("input", None)
|
||||||
|
self._args = args
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
self._is_mutation = is_mutation
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self) -> AuthUser:
|
||||||
|
return self._user
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resolve_info(self) -> Optional[GraphQLResolveInfo]:
|
||||||
|
return self._resolve_info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filter(self) -> dict:
|
||||||
|
return self._filter
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sort(self) -> dict:
|
||||||
|
return self._sort
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skip(self) -> Optional[int]:
|
||||||
|
return self._skip
|
||||||
|
|
||||||
|
@property
|
||||||
|
def take(self) -> Optional[int]:
|
||||||
|
return self._take
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input(self) -> Optional[Any]:
|
||||||
|
return self._input
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> tuple:
|
||||||
|
return self._args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kwargs(self) -> dict:
|
||||||
|
return self._kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mutation(self) -> bool:
|
||||||
|
return self._is_mutation
|
||||||
|
|
||||||
|
def has_permission(self, permission: Enum | str) -> bool:
|
||||||
|
return permission.value if isinstance(permission, Enum) else permission in self._user_permissions
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
|
from enum import Enum
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from cpl.graphql.schema.argument import Argument
|
from cpl.graphql.schema.argument import Argument
|
||||||
from cpl.graphql.typing import TQuery, Resolver
|
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
|
||||||
|
|
||||||
|
|
||||||
class Field:
|
class Field:
|
||||||
@@ -9,7 +10,7 @@ class Field:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
gql_type: type = None,
|
t: type = None,
|
||||||
resolver: Resolver = None,
|
resolver: Resolver = None,
|
||||||
optional=None,
|
optional=None,
|
||||||
default=None,
|
default=None,
|
||||||
@@ -17,7 +18,7 @@ class Field:
|
|||||||
parent_type=None,
|
parent_type=None,
|
||||||
):
|
):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._gql_type = gql_type
|
self._type = t
|
||||||
self._resolver = resolver
|
self._resolver = resolver
|
||||||
self._optional = optional or True
|
self._optional = optional or True
|
||||||
self._default = default
|
self._default = default
|
||||||
@@ -26,6 +27,9 @@ class Field:
|
|||||||
self._parent_type = parent_type
|
self._parent_type = parent_type
|
||||||
|
|
||||||
self._args: dict[str, Argument] = {}
|
self._args: dict[str, Argument] = {}
|
||||||
|
self._require_any_permission = None
|
||||||
|
self._require_any = None
|
||||||
|
self._public = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -33,7 +37,7 @@ class Field:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> type:
|
def type(self) -> type:
|
||||||
return self._gql_type
|
return self._type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resolver(self) -> callable:
|
def resolver(self) -> callable:
|
||||||
@@ -63,6 +67,34 @@ class Field:
|
|||||||
def arguments(self) -> dict[str, Argument]:
|
def arguments(self) -> dict[str, Argument]:
|
||||||
return self._args
|
return self._args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def require_any_permission(self) -> TRequireAnyPermissions | None:
|
||||||
|
return self._require_any_permission
|
||||||
|
|
||||||
|
@property
|
||||||
|
def require_any(self) -> TRequireAnyResolvers | None:
|
||||||
|
return self._require_any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def public(self) -> bool:
|
||||||
|
return self._public
|
||||||
|
|
||||||
|
def with_type(self, t: type) -> Self:
|
||||||
|
self._type = t
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_resolver(self, resolver: Resolver) -> Self:
|
||||||
|
self._resolver = resolver
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_optional(self, optional: bool) -> Self:
|
||||||
|
self._optional = optional
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_default(self, default) -> Self:
|
||||||
|
self._default = default
|
||||||
|
return self
|
||||||
|
|
||||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
||||||
if name in self._args:
|
if name in self._args:
|
||||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||||
@@ -76,3 +108,21 @@ class Field:
|
|||||||
|
|
||||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self:
|
||||||
|
assert permissions is not None, "require_any_permission cannot be None"
|
||||||
|
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
|
||||||
|
self._require_any_permission = permissions
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self:
|
||||||
|
assert permissions is not None, "permissions cannot be None"
|
||||||
|
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
|
||||||
|
assert resolvers is not None, "resolvers cannot be None"
|
||||||
|
assert all(callable(r) for r in resolvers), "All resolvers must be callable"
|
||||||
|
self._require_any = (permissions, resolvers)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_public(self, public: bool = False) -> Self:
|
||||||
|
self._public = public
|
||||||
|
return self
|
||||||
|
|||||||
@@ -1,13 +1,19 @@
|
|||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, Type, Any, Optional
|
from typing import Callable, Type, Any, Optional
|
||||||
|
|
||||||
import strawberry
|
import strawberry
|
||||||
from strawberry.exceptions import StrawberryException
|
from strawberry.exceptions import StrawberryException
|
||||||
|
|
||||||
|
from cpl.api import Unauthorized, Forbidden
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
|
from cpl.core.ctx import get_user
|
||||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||||
from cpl.dependency.inject import inject
|
from cpl.dependency.inject import inject
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
|
from cpl.graphql.error import graphql_error
|
||||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
@@ -202,23 +208,70 @@ class Query(StrawberryProtocol):
|
|||||||
_resolver.__signature__ = sig
|
_resolver.__signature__ = sig
|
||||||
return _resolver
|
return _resolver
|
||||||
|
|
||||||
|
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||||
|
# Signatur vom Original übernehmen
|
||||||
|
sig = getattr(resolver, "__signature__", None)
|
||||||
|
|
||||||
|
@functools.wraps(resolver)
|
||||||
|
async def _auth_resolver(*args, **kwargs):
|
||||||
|
request = get_request()
|
||||||
|
user = get_user()
|
||||||
|
|
||||||
|
# Public
|
||||||
|
if f.public:
|
||||||
|
return await self._maybe_await(resolver(*args, **kwargs))
|
||||||
|
|
||||||
|
# Auth required
|
||||||
|
if user is None:
|
||||||
|
raise graphql_error(Unauthorized("Authentication required"))
|
||||||
|
|
||||||
|
# Permissions
|
||||||
|
if f.require_any_permission:
|
||||||
|
if not any(user.has_permission(p) for p in f.require_any_permission):
|
||||||
|
raise Forbidden("Permission denied")
|
||||||
|
|
||||||
|
# Custom resolvers
|
||||||
|
if f.require_any:
|
||||||
|
perms, resolvers = f.require_any
|
||||||
|
if not any(user.has_permission(p) for p in perms):
|
||||||
|
for r in resolvers:
|
||||||
|
ok = await self._maybe_await(r(user, *args, **kwargs))
|
||||||
|
if ok:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Forbidden("Permission denied")
|
||||||
|
|
||||||
|
return await self._maybe_await(resolver(*args, **kwargs))
|
||||||
|
|
||||||
|
# Signatur beibehalten
|
||||||
|
if sig:
|
||||||
|
_auth_resolver.__signature__ = sig
|
||||||
|
|
||||||
|
return _auth_resolver
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _maybe_await(value):
|
||||||
|
if asyncio.iscoroutine(value):
|
||||||
|
return value
|
||||||
|
return asyncio.sleep(0, result=value) # sofort resolved Future
|
||||||
|
|
||||||
|
|
||||||
def _field_to_strawberry(self, f: Field) -> Any:
|
def _field_to_strawberry(self, f: Field) -> Any:
|
||||||
|
resolver = None
|
||||||
try:
|
try:
|
||||||
if f.resolver:
|
if f.arguments:
|
||||||
|
resolver = self._build_resolver(f)
|
||||||
|
elif not f.resolver:
|
||||||
|
resolver = lambda *_, **__: None
|
||||||
|
else:
|
||||||
ann = getattr(f.resolver, "__annotations__", {})
|
ann = getattr(f.resolver, "__annotations__", {})
|
||||||
if "return" not in ann or ann["return"] is None:
|
if "return" not in ann or ann["return"] is None:
|
||||||
ann = dict(ann)
|
ann = dict(ann)
|
||||||
ann["return"] = f.type
|
ann["return"] = f.type
|
||||||
f.resolver.__annotations__ = ann
|
f.resolver.__annotations__ = ann
|
||||||
|
resolver = f.resolver
|
||||||
|
|
||||||
if f.arguments:
|
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||||
resolver = self._build_resolver(f)
|
|
||||||
return strawberry.field(resolver=resolver)
|
|
||||||
|
|
||||||
if not f.resolver:
|
|
||||||
return strawberry.field(resolver=lambda *_, **__: None)
|
|
||||||
|
|
||||||
return strawberry.field(resolver=f.resolver)
|
|
||||||
except StrawberryException as e:
|
except StrawberryException as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Error converting field '{f.name}' to strawberry field: {e}"
|
f"Error converting field '{f.name}' to strawberry field: {e}"
|
||||||
|
|||||||
51
src/cpl-graphql/cpl/graphql/service/graphql.py
Normal file
51
src/cpl-graphql/cpl/graphql/service/graphql.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from graphql import GraphQLError
|
||||||
|
|
||||||
|
from cpl.api import APILogger, APIError
|
||||||
|
from cpl.api.typing import TRequest
|
||||||
|
from cpl.graphql.service.schema import Schema
|
||||||
|
|
||||||
|
|
||||||
|
class GraphQLService:
|
||||||
|
def __init__(self, logger: APILogger, schema: Schema):
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
if schema.schema is None:
|
||||||
|
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
|
||||||
|
self._schema = schema.schema
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
variables: Optional[Dict[str, Any]],
|
||||||
|
request: TRequest,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
result = await self._schema.execute(
|
||||||
|
query,
|
||||||
|
variable_values=variables,
|
||||||
|
context_value={"request": request},
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data: Dict[str, Any] = {}
|
||||||
|
if result.errors:
|
||||||
|
errors = []
|
||||||
|
for error in result.errors:
|
||||||
|
if isinstance(error, GraphQLError):
|
||||||
|
self._logger.error(f"GraphQL APIError: {error}")
|
||||||
|
errors.append({"message": error.message, "extensions": error.extensions})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(error, APIError):
|
||||||
|
self._logger.error(f"GraphQL APIError: {error}")
|
||||||
|
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._logger.error(f"GraphQL unexpected error: {error}")
|
||||||
|
errors.append({"message": str(error), "extensions": {"code": 500}})
|
||||||
|
|
||||||
|
response_data["errors"] = errors
|
||||||
|
if result.data:
|
||||||
|
response_data["data"] = result.data
|
||||||
|
|
||||||
|
return response_data
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
|
import logging
|
||||||
from typing import Type, Self
|
from typing import Type, Self
|
||||||
|
|
||||||
import strawberry
|
import strawberry
|
||||||
|
from starlette.requests import Request
|
||||||
|
from strawberry.types import ExecutionContext
|
||||||
|
|
||||||
|
from cpl.api import APIError
|
||||||
from cpl.api.logger import APILogger
|
from cpl.api.logger import APILogger
|
||||||
from cpl.dependency.service_provider import ServiceProvider
|
from cpl.dependency.service_provider import ServiceProvider
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
@@ -40,7 +44,18 @@ class Schema:
|
|||||||
|
|
||||||
return types
|
return types
|
||||||
|
|
||||||
|
def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext):
|
||||||
|
request: Request = execution_context.context.get("request")
|
||||||
|
|
||||||
|
if isinstance(error, APIError):
|
||||||
|
self._logger.error(f"GraphQL APIError: {error}")
|
||||||
|
return {"message": error.error_message, "extensions": {"code": error.status_code}}
|
||||||
|
|
||||||
|
self._logger.error(f"GraphQL unexpected error: {error}")
|
||||||
|
return {"message": str(error), "extensions": {"code": 500}}
|
||||||
|
|
||||||
def build(self) -> strawberry.Schema:
|
def build(self) -> strawberry.Schema:
|
||||||
|
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
||||||
query = self._provider.get_service(RootQuery)
|
query = self._provider.get_service(RootQuery)
|
||||||
if not query:
|
if not query:
|
||||||
raise ValueError("RootQuery not registered in service provider")
|
raise ValueError("RootQuery not registered in service provider")
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from cpl.api.typing import TRequest
|
|
||||||
from cpl.graphql.service.schema import Schema
|
|
||||||
|
|
||||||
|
|
||||||
class GraphQLService:
|
|
||||||
def __init__(self, schema: Schema):
|
|
||||||
if schema.schema is None:
|
|
||||||
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
|
|
||||||
self._schema = schema.schema
|
|
||||||
|
|
||||||
async def execute(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
variables: Optional[Dict[str, Any]],
|
|
||||||
request: TRequest,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
result = await self._schema.execute(
|
|
||||||
query,
|
|
||||||
variable_values=variables,
|
|
||||||
context_value={"request": request},
|
|
||||||
)
|
|
||||||
|
|
||||||
response_data: Dict[str, Any] = {}
|
|
||||||
if result.errors:
|
|
||||||
response_data["errors"] = [str(e) for e in result.errors]
|
|
||||||
if result.data:
|
|
||||||
response_data["data"] = result.data
|
|
||||||
|
|
||||||
return response_data
|
|
||||||
@@ -1,5 +1,15 @@
|
|||||||
from typing import Type, Callable
|
from enum import Enum
|
||||||
|
from typing import Type, Callable, List, Tuple, Awaitable
|
||||||
|
|
||||||
|
from cpl.auth.permission import Permissions
|
||||||
|
from cpl.graphql.query_context import QueryContext
|
||||||
|
|
||||||
TQuery = Type["Query"]
|
TQuery = Type["Query"]
|
||||||
Resolver = Callable
|
Resolver = Callable
|
||||||
ScalarType = str | int | float | bool | object
|
ScalarType = str | int | float | bool | object
|
||||||
|
|
||||||
|
TRequireAnyPermissions = List[Enum | Permissions] | None
|
||||||
|
TRequireAnyResolvers = List[
|
||||||
|
Callable[[QueryContext], bool | Awaitable[bool]],
|
||||||
|
]
|
||||||
|
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]
|
||||||
|
|||||||
Reference in New Issue
Block a user