From 12436c296b8380e7e7414308861a1c5184a69390 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 11:45:51 +0200 Subject: [PATCH] [WIP] with authentication #181 --- example/api/src/main.py | 2 +- example/api/src/model/author_query.py | 6 +- example/api/src/model/post_query.py | 8 +- .../cpl/api/middleware/authentication.py | 16 ++++ src/cpl-api/cpl/api/middleware/request.py | 40 ++++++++- .../cpl/graphql/_endpoints/graphql.py | 2 +- src/cpl-graphql/cpl/graphql/error.py | 14 +++ src/cpl-graphql/cpl/graphql/graphql_module.py | 2 +- src/cpl-graphql/cpl/graphql/query_context.py | 90 +++++++++++++++++++ src/cpl-graphql/cpl/graphql/schema/field.py | 58 +++++++++++- src/cpl-graphql/cpl/graphql/schema/query.py | 71 +++++++++++++-- .../cpl/graphql/service/graphql.py | 51 +++++++++++ src/cpl-graphql/cpl/graphql/service/schema.py | 15 ++++ .../cpl/graphql/service/service.py | 31 ------- src/cpl-graphql/cpl/graphql/typing.py | 14 ++- 15 files changed, 363 insertions(+), 57 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/error.py create mode 100644 src/cpl-graphql/cpl/graphql/query_context.py create mode 100644 src/cpl-graphql/cpl/graphql/service/graphql.py delete mode 100644 src/cpl-graphql/cpl/graphql/service/service.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 7273381a..c149fe47 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -85,7 +85,7 @@ def main(): schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) - schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) + schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True) app.with_playground() app.with_graphiql() diff --git a/example/api/src/model/author_query.py b/example/api/src/model/author_query.py index f7f1d1df..ae365a7c 100644 --- a/example/api/src/model/author_query.py +++ b/example/api/src/model/author_query.py @@ -26,12 +26,12 @@ class AuthorGraphType(GraphType[Author]): self.int_field( "id", resolver=lambda root: root.id, - ) + ).with_public(True) self.string_field( "firstName", resolver=lambda root: root.first_name, - ) + ).with_public(True) self.string_field( "lastName", resolver=lambda root: root.last_name, - ) + ).with_public(True) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index e3bc41af..48845617 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -29,7 +29,7 @@ class PostGraphType(GraphType[Post]): self.int_field( "id", resolver=lambda root: root.id, - ) + ).with_public(True) async def _a(root: Post): return await authors.get_by_id(root.author_id) @@ -38,12 +38,12 @@ class PostGraphType(GraphType[Post]): "author", AuthorGraphType, resolver=_a#lambda root: root.author_id, - ) + ).with_public(True) self.string_field( "title", resolver=lambda root: root.title, - ) + ).with_public(True) self.string_field( "content", resolver=lambda root: root.content, - ) + ).with_public(True) diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index c0dc95f1..2412a1c6 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -25,6 +25,22 @@ class AuthenticationMiddleware(ASGIMiddleware): request = get_request() url = request.url.path + if url not in Router.get_auth_required_routes(): + self._logger.trace(f"No authentication required for {url}") + return await self._app(scope, receive, send) + + # ab hier Auth erzwingen + user = getattr(request.state, "user", None) + if not user or user.deleted: + self._logger.debug(f"Unauthorized access to {url}, user missing or deleted") + return await Unauthorized("Unauthorized").asgi_response(scope, receive, send) + + return await self._call_next(scope, receive, send) + + async def _old_call__(self, scope: Scope, receive: Receive, send: Send): + request = get_request() + url = request.url.path + if url not in Router.get_auth_required_routes(): self._logger.trace(f"No authentication required for {url}") return await self._app(scope, receive, send) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 0cedc88b..2dc24bc5 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -9,6 +9,10 @@ from starlette.types import Scope, Receive, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.typing import TRequest +from cpl.auth.keycloak.keycloak_client import KeycloakClient +from cpl.auth.schema import AuthUser +from cpl.auth.schema._administration.auth_user_dao import AuthUserDao +from cpl.core.ctx import set_user from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider @@ -17,12 +21,15 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): - def __init__(self, app, provider: ServiceProvider, logger: APILogger): + def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): ASGIMiddleware.__init__(self, app) self._provider = provider self._logger = logger + self._keycloak = keycloak + self._user_dao = user_dao + self._ctx_token = None async def __call__(self, scope: Scope, receive: Receive, send: Send): @@ -30,6 +37,7 @@ class RequestMiddleware(ASGIMiddleware): await self.set_request_data(request) try: + await self._try_set_user(request) with self._provider.create_scope(): inject(await self._app(scope, receive, send)) finally: @@ -53,6 +61,36 @@ class RequestMiddleware(ASGIMiddleware): self._logger.trace(f"Clearing current request: {request.state.request_id}") _request_context.reset(self._ctx_token) + async def _try_set_user(self, request: Request): + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return + + token = auth_header.split("Bearer ")[1] + try: + token_info = self._keycloak.introspect(token) + if not token_info.get("active", False): + return + + keycloak_id = self._keycloak.get_user_id(token) + if not keycloak_id: + return + + user = await self._user_dao.find_by_keycloak_id(keycloak_id) + if not user: + user = AuthUser(0, keycloak_id) + uid = await self._user_dao.create(user) + user = await self._user_dao.get_by_id(uid) + + if user.deleted: + return + + request.state.user = user + set_user(user) + self._logger.trace(f"User {user.id} bound to request {request.state.request_id}") + + except Exception as e: + self._logger.debug(f"Silent user binding failed: {e}") def get_request() -> Optional[TRequest]: return _request_context.get() diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py index 0808d704..01cb133b 100644 --- a/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py @@ -1,7 +1,7 @@ from starlette.requests import Request from starlette.responses import Response, JSONResponse -from cpl.graphql.service.service import GraphQLService +from cpl.graphql.service.graphql import GraphQLService async def graphql_endpoint(request: Request, service: GraphQLService) -> Response: diff --git a/src/cpl-graphql/cpl/graphql/error.py b/src/cpl-graphql/cpl/graphql/error.py new file mode 100644 index 00000000..e96e41c1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/error.py @@ -0,0 +1,14 @@ +from graphql import GraphQLError + +from cpl.api import APIError + + +def graphql_error(api_error: APIError) -> GraphQLError: + """Convert an APIError (from cpl-api) into a GraphQL-friendly error.""" + return GraphQLError( + message=api_error.error_message, + extensions={ + "code": api_error.status_code, + }, + original_error=api_error, + ) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 2d5d6b93..d9d66aee 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -8,7 +8,7 @@ from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema -from cpl.graphql.service.service import GraphQLService +from cpl.graphql.service.graphql import GraphQLService class GraphQLModule(Module): diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py new file mode 100644 index 00000000..79d0b965 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -0,0 +1,90 @@ +from enum import Enum +from typing import Optional, Any + +from graphql import GraphQLResolveInfo + +from cpl.auth.schema import AuthUser, Permission +from cpl.core.utils import get_value + + +class QueryContext: + + def __init__( + self, + data: Any, + user: Optional[AuthUser], + user_permissions: Optional[list[Enum | Permission]], + is_mutation: bool = False, + *args, + **kwargs + ): + + self._data = data + self._user = user + if user_permissions is None: + user_permissions = [] + self._user_permissions: list[str] = [x.name for x in user_permissions] + + self._resolve_info = None + for arg in args: + if isinstance(arg, GraphQLResolveInfo): + self._resolve_info = arg + continue + + self._filter = kwargs.get("filters", {}) + self._sort = kwargs.get("sort", {}) + self._skip = get_value(kwargs, "skip", int) + self._take = get_value(kwargs, "take", int) + + self._input = kwargs.get("input", None) + self._args = args + self._kwargs = kwargs + + self._is_mutation = is_mutation + + @property + def data(self): + return self._data + + @property + def user(self) -> AuthUser: + return self._user + + @property + def resolve_info(self) -> Optional[GraphQLResolveInfo]: + return self._resolve_info + + @property + def filter(self) -> dict: + return self._filter + + @property + def sort(self) -> dict: + return self._sort + + @property + def skip(self) -> Optional[int]: + return self._skip + + @property + def take(self) -> Optional[int]: + return self._take + + @property + def input(self) -> Optional[Any]: + return self._input + + @property + def args(self) -> tuple: + return self._args + + @property + def kwargs(self) -> dict: + return self._kwargs + + @property + def is_mutation(self) -> bool: + return self._is_mutation + + def has_permission(self, permission: Enum | str) -> bool: + return permission.value if isinstance(permission, Enum) else permission in self._user_permissions diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 2231e11c..d9417bdb 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -1,7 +1,8 @@ +from enum import Enum from typing import Self from cpl.graphql.schema.argument import Argument -from cpl.graphql.typing import TQuery, Resolver +from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers class Field: @@ -9,7 +10,7 @@ class Field: def __init__( self, name: str, - gql_type: type = None, + t: type = None, resolver: Resolver = None, optional=None, default=None, @@ -17,7 +18,7 @@ class Field: parent_type=None, ): self._name = name - self._gql_type = gql_type + self._type = t self._resolver = resolver self._optional = optional or True self._default = default @@ -26,6 +27,9 @@ class Field: self._parent_type = parent_type self._args: dict[str, Argument] = {} + self._require_any_permission = None + self._require_any = None + self._public = False @property def name(self) -> str: @@ -33,7 +37,7 @@ class Field: @property def type(self) -> type: - return self._gql_type + return self._type @property def resolver(self) -> callable: @@ -63,6 +67,34 @@ class Field: def arguments(self) -> dict[str, Argument]: return self._args + @property + def require_any_permission(self) -> TRequireAnyPermissions | None: + return self._require_any_permission + + @property + def require_any(self) -> TRequireAnyResolvers | None: + return self._require_any + + @property + def public(self) -> bool: + return self._public + + def with_type(self, t: type) -> Self: + self._type = t + return self + + def with_resolver(self, resolver: Resolver) -> Self: + self._resolver = resolver + return self + + def with_optional(self, optional: bool) -> Self: + self._optional = optional + return self + + def with_default(self, default) -> Self: + self._default = default + return self + def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") @@ -76,3 +108,21 @@ class Field: self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) return self + + def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self: + assert permissions is not None, "require_any_permission cannot be None" + assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + self._require_any_permission = permissions + return self + + def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self: + assert permissions is not None, "permissions cannot be None" + assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + assert resolvers is not None, "resolvers cannot be None" + assert all(callable(r) for r in resolvers), "All resolvers must be callable" + self._require_any = (permissions, resolvers) + return self + + def with_public(self, public: bool = False) -> Self: + self._public = public + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 84270056..2f0e23f0 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,13 +1,19 @@ +import asyncio +import functools import inspect from typing import Callable, Type, Any, Optional import strawberry from strawberry.exceptions import StrawberryException +from cpl.api import Unauthorized, Forbidden +from cpl.api.middleware.request import get_request +from cpl.core.ctx import get_user from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.error import graphql_error from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder @@ -202,23 +208,70 @@ class Query(StrawberryProtocol): _resolver.__signature__ = sig return _resolver + def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: + # Signatur vom Original übernehmen + sig = getattr(resolver, "__signature__", None) + + @functools.wraps(resolver) + async def _auth_resolver(*args, **kwargs): + request = get_request() + user = get_user() + + # Public + if f.public: + return await self._maybe_await(resolver(*args, **kwargs)) + + # Auth required + if user is None: + raise graphql_error(Unauthorized("Authentication required")) + + # Permissions + if f.require_any_permission: + if not any(user.has_permission(p) for p in f.require_any_permission): + raise Forbidden("Permission denied") + + # Custom resolvers + if f.require_any: + perms, resolvers = f.require_any + if not any(user.has_permission(p) for p in perms): + for r in resolvers: + ok = await self._maybe_await(r(user, *args, **kwargs)) + if ok: + break + else: + raise Forbidden("Permission denied") + + return await self._maybe_await(resolver(*args, **kwargs)) + + # Signatur beibehalten + if sig: + _auth_resolver.__signature__ = sig + + return _auth_resolver + + @staticmethod + def _maybe_await(value): + if asyncio.iscoroutine(value): + return value + return asyncio.sleep(0, result=value) # sofort resolved Future + + def _field_to_strawberry(self, f: Field) -> Any: + resolver = None try: - if f.resolver: + if f.arguments: + resolver = self._build_resolver(f) + elif not f.resolver: + resolver = lambda *_, **__: None + else: ann = getattr(f.resolver, "__annotations__", {}) if "return" not in ann or ann["return"] is None: ann = dict(ann) ann["return"] = f.type f.resolver.__annotations__ = ann + resolver = f.resolver - if f.arguments: - resolver = self._build_resolver(f) - return strawberry.field(resolver=resolver) - - if not f.resolver: - return strawberry.field(resolver=lambda *_, **__: None) - - return strawberry.field(resolver=f.resolver) + return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) except StrawberryException as e: raise Exception( f"Error converting field '{f.name}' to strawberry field: {e}" diff --git a/src/cpl-graphql/cpl/graphql/service/graphql.py b/src/cpl-graphql/cpl/graphql/service/graphql.py new file mode 100644 index 00000000..c816b2e1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/graphql.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional + +from graphql import GraphQLError + +from cpl.api import APILogger, APIError +from cpl.api.typing import TRequest +from cpl.graphql.service.schema import Schema + + +class GraphQLService: + def __init__(self, logger: APILogger, schema: Schema): + self._logger = logger + + if schema.schema is None: + raise ValueError("Schema has not been built. Call schema.build() before using the service.") + self._schema = schema.schema + + async def execute( + self, + query: str, + variables: Optional[Dict[str, Any]], + request: TRequest, + ) -> Dict[str, Any]: + result = await self._schema.execute( + query, + variable_values=variables, + context_value={"request": request}, + ) + + response_data: Dict[str, Any] = {} + if result.errors: + errors = [] + for error in result.errors: + if isinstance(error, GraphQLError): + self._logger.error(f"GraphQL APIError: {error}") + errors.append({"message": error.message, "extensions": error.extensions}) + continue + + if isinstance(error, APIError): + self._logger.error(f"GraphQL APIError: {error}") + errors.append({"message": error.error_message, "extensions": {"code": error.status_code}}) + continue + + self._logger.error(f"GraphQL unexpected error: {error}") + errors.append({"message": str(error), "extensions": {"code": 500}}) + + response_data["errors"] = errors + if result.data: + response_data["data"] = result.data + + return response_data diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 23627ee4..f0b01b05 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,7 +1,11 @@ +import logging from typing import Type, Self import strawberry +from starlette.requests import Request +from strawberry.types import ExecutionContext +from cpl.api import APIError from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -40,7 +44,18 @@ class Schema: return types + def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext): + request: Request = execution_context.context.get("request") + + if isinstance(error, APIError): + self._logger.error(f"GraphQL APIError: {error}") + return {"message": error.error_message, "extensions": {"code": error.status_code}} + + self._logger.error(f"GraphQL unexpected error: {error}") + return {"message": str(error), "extensions": {"code": 500}} + def build(self) -> strawberry.Schema: + logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) query = self._provider.get_service(RootQuery) if not query: raise ValueError("RootQuery not registered in service provider") diff --git a/src/cpl-graphql/cpl/graphql/service/service.py b/src/cpl-graphql/cpl/graphql/service/service.py deleted file mode 100644 index f039ccbd..00000000 --- a/src/cpl-graphql/cpl/graphql/service/service.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any, Dict, Optional - -from cpl.api.typing import TRequest -from cpl.graphql.service.schema import Schema - - -class GraphQLService: - def __init__(self, schema: Schema): - if schema.schema is None: - raise ValueError("Schema has not been built. Call schema.build() before using the service.") - self._schema = schema.schema - - async def execute( - self, - query: str, - variables: Optional[Dict[str, Any]], - request: TRequest, - ) -> Dict[str, Any]: - result = await self._schema.execute( - query, - variable_values=variables, - context_value={"request": request}, - ) - - response_data: Dict[str, Any] = {} - if result.errors: - response_data["errors"] = [str(e) for e in result.errors] - if result.data: - response_data["data"] = result.data - - return response_data diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py index d5b63494..d36e3119 100644 --- a/src/cpl-graphql/cpl/graphql/typing.py +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -1,5 +1,15 @@ -from typing import Type, Callable +from enum import Enum +from typing import Type, Callable, List, Tuple, Awaitable + +from cpl.auth.permission import Permissions +from cpl.graphql.query_context import QueryContext TQuery = Type["Query"] Resolver = Callable -ScalarType = str | int | float | bool | object \ No newline at end of file +ScalarType = str | int | float | bool | object + +TRequireAnyPermissions = List[Enum | Permissions] | None +TRequireAnyResolvers = List[ + Callable[[QueryContext], bool | Awaitable[bool]], +] +TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]