[WIP] with authentication #181
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 7s

This commit is contained in:
2025-09-28 11:45:51 +02:00
parent 652304a480
commit d42b89809b
15 changed files with 362 additions and 57 deletions

View File

@@ -85,7 +85,7 @@ def main():
schema.query.string_field("ping", resolver=lambda: "pong") schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery) schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True)
app.with_playground() app.with_playground()
app.with_graphiql() app.with_graphiql()

View File

@@ -26,12 +26,12 @@ class AuthorGraphType(GraphType[Author]):
self.int_field( self.int_field(
"id", "id",
resolver=lambda root: root.id, resolver=lambda root: root.id,
) ).with_public(True)
self.string_field( self.string_field(
"firstName", "firstName",
resolver=lambda root: root.first_name, resolver=lambda root: root.first_name,
) ).with_public(True)
self.string_field( self.string_field(
"lastName", "lastName",
resolver=lambda root: root.last_name, resolver=lambda root: root.last_name,
) ).with_public(True)

View File

@@ -29,7 +29,7 @@ class PostGraphType(GraphType[Post]):
self.int_field( self.int_field(
"id", "id",
resolver=lambda root: root.id, resolver=lambda root: root.id,
) ).with_public(True)
async def _a(root: Post): async def _a(root: Post):
return await authors.get_by_id(root.author_id) return await authors.get_by_id(root.author_id)
@@ -38,12 +38,12 @@ class PostGraphType(GraphType[Post]):
"author", "author",
AuthorGraphType, AuthorGraphType,
resolver=_a#lambda root: root.author_id, resolver=_a#lambda root: root.author_id,
) ).with_public(True)
self.string_field( self.string_field(
"title", "title",
resolver=lambda root: root.title, resolver=lambda root: root.title,
) ).with_public(True)
self.string_field( self.string_field(
"content", "content",
resolver=lambda root: root.content, resolver=lambda root: root.content,
) ).with_public(True)

View File

@@ -25,6 +25,21 @@ class AuthenticationMiddleware(ASGIMiddleware):
request = get_request() request = get_request()
url = request.url.path url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)
user = getattr(request.state, "user", None)
if not user or user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user missing or deleted")
return await Unauthorized("Unauthorized").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
async def _old_call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes(): if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}") self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send) return await self._app(scope, receive, send)

View File

@@ -9,6 +9,10 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
from cpl.auth.keycloak.keycloak_client import KeycloakClient
from cpl.auth.schema import AuthUser
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx import set_user
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
@@ -17,12 +21,15 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware): class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger): def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._provider = provider self._provider = provider
self._logger = logger self._logger = logger
self._keycloak = keycloak
self._user_dao = user_dao
self._ctx_token = None self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
@@ -30,6 +37,7 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request) await self.set_request_data(request)
try: try:
await self._try_set_user(request)
with self._provider.create_scope(): with self._provider.create_scope():
inject(await self._app(scope, receive, send)) inject(await self._app(scope, receive, send))
finally: finally:
@@ -53,6 +61,36 @@ class RequestMiddleware(ASGIMiddleware):
self._logger.trace(f"Clearing current request: {request.state.request_id}") self._logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token) _request_context.reset(self._ctx_token)
async def _try_set_user(self, request: Request):
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
token = auth_header.split("Bearer ")[1]
try:
token_info = self._keycloak.introspect(token)
if not token_info.get("active", False):
return
keycloak_id = self._keycloak.get_user_id(token)
if not keycloak_id:
return
user = await self._user_dao.find_by_keycloak_id(keycloak_id)
if not user:
user = AuthUser(0, keycloak_id)
uid = await self._user_dao.create(user)
user = await self._user_dao.get_by_id(uid)
if user.deleted:
return
request.state.user = user
set_user(user)
self._logger.trace(f"User {user.id} bound to request {request.state.request_id}")
except Exception as e:
self._logger.debug(f"Silent user binding failed: {e}")
def get_request() -> Optional[TRequest]: def get_request() -> Optional[TRequest]:
return _request_context.get() return _request_context.get()

View File

@@ -1,7 +1,7 @@
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, JSONResponse from starlette.responses import Response, JSONResponse
from cpl.graphql.service.service import GraphQLService from cpl.graphql.service.graphql import GraphQLService
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response: async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:

View File

@@ -0,0 +1,14 @@
from graphql import GraphQLError
from cpl.api import APIError
def graphql_error(api_error: APIError) -> GraphQLError:
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
return GraphQLError(
message=api_error.error_message,
extensions={
"code": api_error.status_code,
},
original_error=api_error,
)

View File

@@ -8,7 +8,7 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema from cpl.graphql.service.schema import Schema
from cpl.graphql.service.service import GraphQLService from cpl.graphql.service.graphql import GraphQLService
class GraphQLModule(Module): class GraphQLModule(Module):

View File

@@ -0,0 +1,90 @@
from enum import Enum
from typing import Optional, Any
from graphql import GraphQLResolveInfo
from cpl.auth.schema import AuthUser, Permission
from cpl.core.utils import get_value
class QueryContext:
def __init__(
self,
data: Any,
user: Optional[AuthUser],
user_permissions: Optional[list[Enum | Permission]],
is_mutation: bool = False,
*args,
**kwargs
):
self._data = data
self._user = user
if user_permissions is None:
user_permissions = []
self._user_permissions: list[str] = [x.name for x in user_permissions]
self._resolve_info = None
for arg in args:
if isinstance(arg, GraphQLResolveInfo):
self._resolve_info = arg
continue
self._filter = kwargs.get("filters", {})
self._sort = kwargs.get("sort", {})
self._skip = get_value(kwargs, "skip", int)
self._take = get_value(kwargs, "take", int)
self._input = kwargs.get("input", None)
self._args = args
self._kwargs = kwargs
self._is_mutation = is_mutation
@property
def data(self):
return self._data
@property
def user(self) -> AuthUser:
return self._user
@property
def resolve_info(self) -> Optional[GraphQLResolveInfo]:
return self._resolve_info
@property
def filter(self) -> dict:
return self._filter
@property
def sort(self) -> dict:
return self._sort
@property
def skip(self) -> Optional[int]:
return self._skip
@property
def take(self) -> Optional[int]:
return self._take
@property
def input(self) -> Optional[Any]:
return self._input
@property
def args(self) -> tuple:
return self._args
@property
def kwargs(self) -> dict:
return self._kwargs
@property
def is_mutation(self) -> bool:
return self._is_mutation
def has_permission(self, permission: Enum | str) -> bool:
return permission.value if isinstance(permission, Enum) else permission in self._user_permissions

View File

@@ -1,7 +1,8 @@
from enum import Enum
from typing import Self from typing import Self
from cpl.graphql.schema.argument import Argument from cpl.graphql.schema.argument import Argument
from cpl.graphql.typing import TQuery, Resolver from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
class Field: class Field:
@@ -9,7 +10,7 @@ class Field:
def __init__( def __init__(
self, self,
name: str, name: str,
gql_type: type = None, t: type = None,
resolver: Resolver = None, resolver: Resolver = None,
optional=None, optional=None,
default=None, default=None,
@@ -17,7 +18,7 @@ class Field:
parent_type=None, parent_type=None,
): ):
self._name = name self._name = name
self._gql_type = gql_type self._type = t
self._resolver = resolver self._resolver = resolver
self._optional = optional or True self._optional = optional or True
self._default = default self._default = default
@@ -26,6 +27,9 @@ class Field:
self._parent_type = parent_type self._parent_type = parent_type
self._args: dict[str, Argument] = {} self._args: dict[str, Argument] = {}
self._require_any_permission = None
self._require_any = None
self._public = False
@property @property
def name(self) -> str: def name(self) -> str:
@@ -33,7 +37,7 @@ class Field:
@property @property
def type(self) -> type: def type(self) -> type:
return self._gql_type return self._type
@property @property
def resolver(self) -> callable: def resolver(self) -> callable:
@@ -63,6 +67,34 @@ class Field:
def arguments(self) -> dict[str, Argument]: def arguments(self) -> dict[str, Argument]:
return self._args return self._args
@property
def require_any_permission(self) -> TRequireAnyPermissions | None:
return self._require_any_permission
@property
def require_any(self) -> TRequireAnyResolvers | None:
return self._require_any
@property
def public(self) -> bool:
return self._public
def with_type(self, t: type) -> Self:
self._type = t
return self
def with_resolver(self, resolver: Resolver) -> Self:
self._resolver = resolver
return self
def with_optional(self, optional: bool) -> Self:
self._optional = optional
return self
def with_default(self, default) -> Self:
self._default = default
return self
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
if name in self._args: if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
@@ -76,3 +108,21 @@ class Field:
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
return self return self
def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self:
assert permissions is not None, "require_any_permission cannot be None"
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
self._require_any_permission = permissions
return self
def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self:
assert permissions is not None, "permissions cannot be None"
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
assert resolvers is not None, "resolvers cannot be None"
assert all(callable(r) for r in resolvers), "All resolvers must be callable"
self._require_any = (permissions, resolvers)
return self
def with_public(self, public: bool = False) -> Self:
self._public = public
return self

View File

@@ -1,13 +1,19 @@
import asyncio
import functools
import inspect import inspect
from typing import Callable, Type, Any, Optional from typing import Callable, Type, Any, Optional
import strawberry import strawberry
from strawberry.exceptions import StrawberryException from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.api.middleware.request import get_request
from cpl.core.ctx import get_user
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
@@ -202,23 +208,70 @@ class Query(StrawberryProtocol):
_resolver.__signature__ = sig _resolver.__signature__ = sig
return _resolver return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
# Signatur vom Original übernehmen
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
request = get_request()
user = get_user()
# Public
if f.public:
return await self._maybe_await(resolver(*args, **kwargs))
# Auth required
if user is None:
raise graphql_error(Unauthorized("Authentication required"))
# Permissions
if f.require_any_permission:
if not any(user.has_permission(p) for p in f.require_any_permission):
raise Forbidden("Permission denied")
# Custom resolvers
if f.require_any:
perms, resolvers = f.require_any
if not any(user.has_permission(p) for p in perms):
for r in resolvers:
ok = await self._maybe_await(r(user, *args, **kwargs))
if ok:
break
else:
raise Forbidden("Permission denied")
return await self._maybe_await(resolver(*args, **kwargs))
# Signatur beibehalten
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
def _maybe_await(value):
if asyncio.iscoroutine(value):
return value
return asyncio.sleep(0, result=value) # sofort resolved Future
def _field_to_strawberry(self, f: Field) -> Any: def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try: try:
if f.resolver: if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
else:
ann = getattr(f.resolver, "__annotations__", {}) ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None: if "return" not in ann or ann["return"] is None:
ann = dict(ann) ann = dict(ann)
ann["return"] = f.type ann["return"] = f.type
f.resolver.__annotations__ = ann f.resolver.__annotations__ = ann
resolver = f.resolver
if f.arguments: return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
resolver = self._build_resolver(f)
return strawberry.field(resolver=resolver)
if not f.resolver:
return strawberry.field(resolver=lambda *_, **__: None)
return strawberry.field(resolver=f.resolver)
except StrawberryException as e: except StrawberryException as e:
raise Exception( raise Exception(
f"Error converting field '{f.name}' to strawberry field: {e}" f"Error converting field '{f.name}' to strawberry field: {e}"

View File

@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional
from graphql import GraphQLError
from cpl.api import APILogger, APIError
from cpl.api.typing import TRequest
from cpl.graphql.service.schema import Schema
class GraphQLService:
def __init__(self, logger: APILogger, schema: Schema):
self._logger = logger
if schema.schema is None:
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
self._schema = schema.schema
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute(
query,
variable_values=variables,
context_value={"request": request},
)
response_data: Dict[str, Any] = {}
if result.errors:
errors = []
for error in result.errors:
if isinstance(error, GraphQLError):
self._logger.error(f"GraphQL APIError: {error}")
errors.append({"message": error.message, "extensions": error.extensions})
continue
if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError: {error}")
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
continue
self._logger.error(f"GraphQL unexpected error: {error}")
errors.append({"message": str(error), "extensions": {"code": 500}})
response_data["errors"] = errors
if result.data:
response_data["data"] = result.data
return response_data

View File

@@ -1,7 +1,11 @@
import logging
from typing import Type, Self from typing import Type, Self
import strawberry import strawberry
from starlette.requests import Request
from strawberry.types import ExecutionContext
from cpl.api import APIError
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
@@ -40,7 +44,18 @@ class Schema:
return types return types
def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext):
request: Request = execution_context.context.get("request")
if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError: {error}")
return {"message": error.error_message, "extensions": {"code": error.status_code}}
self._logger.error(f"GraphQL unexpected error: {error}")
return {"message": str(error), "extensions": {"code": 500}}
def build(self) -> strawberry.Schema: def build(self) -> strawberry.Schema:
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
query = self._provider.get_service(RootQuery) query = self._provider.get_service(RootQuery)
if not query: if not query:
raise ValueError("RootQuery not registered in service provider") raise ValueError("RootQuery not registered in service provider")

View File

@@ -1,31 +0,0 @@
from typing import Any, Dict, Optional
from cpl.api.typing import TRequest
from cpl.graphql.service.schema import Schema
class GraphQLService:
def __init__(self, schema: Schema):
if schema.schema is None:
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
self._schema = schema.schema
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute(
query,
variable_values=variables,
context_value={"request": request},
)
response_data: Dict[str, Any] = {}
if result.errors:
response_data["errors"] = [str(e) for e in result.errors]
if result.data:
response_data["data"] = result.data
return response_data

View File

@@ -1,5 +1,15 @@
from typing import Type, Callable from enum import Enum
from typing import Type, Callable, List, Tuple, Awaitable
from cpl.auth.permission import Permissions
from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"] TQuery = Type["Query"]
Resolver = Callable Resolver = Callable
ScalarType = str | int | float | bool | object ScalarType = str | int | float | bool | object
TRequireAnyPermissions = List[Enum | Permissions] | None
TRequireAnyResolvers = List[
Callable[[QueryContext], bool | Awaitable[bool]],
]
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]