[WIP] with authentication #181

This commit is contained in:
2025-09-28 11:45:51 +02:00
parent 20e5da5770
commit 6f46b94998
15 changed files with 362 additions and 57 deletions

View File

@@ -85,7 +85,7 @@ def main():
schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True)
app.with_playground()
app.with_graphiql()

View File

@@ -26,12 +26,12 @@ class AuthorGraphType(GraphType[Author]):
self.int_field(
"id",
resolver=lambda root: root.id,
)
).with_public(True)
self.string_field(
"firstName",
resolver=lambda root: root.first_name,
)
).with_public(True)
self.string_field(
"lastName",
resolver=lambda root: root.last_name,
)
).with_public(True)

View File

@@ -29,7 +29,7 @@ class PostGraphType(GraphType[Post]):
self.int_field(
"id",
resolver=lambda root: root.id,
)
).with_public(True)
async def _a(root: Post):
return await authors.get_by_id(root.author_id)
@@ -38,12 +38,12 @@ class PostGraphType(GraphType[Post]):
"author",
AuthorGraphType,
resolver=_a#lambda root: root.author_id,
)
).with_public(True)
self.string_field(
"title",
resolver=lambda root: root.title,
)
).with_public(True)
self.string_field(
"content",
resolver=lambda root: root.content,
)
).with_public(True)

View File

@@ -25,6 +25,21 @@ class AuthenticationMiddleware(ASGIMiddleware):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)
user = getattr(request.state, "user", None)
if not user or user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user missing or deleted")
return await Unauthorized("Unauthorized").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
async def _old_call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)

View File

@@ -9,6 +9,10 @@ from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger
from cpl.api.typing import TRequest
from cpl.auth.keycloak.keycloak_client import KeycloakClient
from cpl.auth.schema import AuthUser
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
from cpl.core.ctx import set_user
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
@@ -17,12 +21,15 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
self._provider = provider
self._logger = logger
self._keycloak = keycloak
self._user_dao = user_dao
self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send):
@@ -30,6 +37,7 @@ class RequestMiddleware(ASGIMiddleware):
await self.set_request_data(request)
try:
await self._try_set_user(request)
with self._provider.create_scope():
inject(await self._app(scope, receive, send))
finally:
@@ -53,6 +61,36 @@ class RequestMiddleware(ASGIMiddleware):
self._logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token)
async def _try_set_user(self, request: Request):
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
token = auth_header.split("Bearer ")[1]
try:
token_info = self._keycloak.introspect(token)
if not token_info.get("active", False):
return
keycloak_id = self._keycloak.get_user_id(token)
if not keycloak_id:
return
user = await self._user_dao.find_by_keycloak_id(keycloak_id)
if not user:
user = AuthUser(0, keycloak_id)
uid = await self._user_dao.create(user)
user = await self._user_dao.get_by_id(uid)
if user.deleted:
return
request.state.user = user
set_user(user)
self._logger.trace(f"User {user.id} bound to request {request.state.request_id}")
except Exception as e:
self._logger.debug(f"Silent user binding failed: {e}")
def get_request() -> Optional[TRequest]:
return _request_context.get()

View File

@@ -1,7 +1,7 @@
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from cpl.graphql.service.service import GraphQLService
from cpl.graphql.service.graphql import GraphQLService
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:

View File

@@ -0,0 +1,14 @@
from graphql import GraphQLError
from cpl.api import APIError
def graphql_error(api_error: APIError) -> GraphQLError:
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
return GraphQLError(
message=api_error.error_message,
extensions={
"code": api_error.status_code,
},
original_error=api_error,
)

View File

@@ -8,7 +8,7 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.service.schema import Schema
from cpl.graphql.service.service import GraphQLService
from cpl.graphql.service.graphql import GraphQLService
class GraphQLModule(Module):

View File

@@ -0,0 +1,90 @@
from enum import Enum
from typing import Optional, Any
from graphql import GraphQLResolveInfo
from cpl.auth.schema import AuthUser, Permission
from cpl.core.utils import get_value
class QueryContext:
def __init__(
self,
data: Any,
user: Optional[AuthUser],
user_permissions: Optional[list[Enum | Permission]],
is_mutation: bool = False,
*args,
**kwargs
):
self._data = data
self._user = user
if user_permissions is None:
user_permissions = []
self._user_permissions: list[str] = [x.name for x in user_permissions]
self._resolve_info = None
for arg in args:
if isinstance(arg, GraphQLResolveInfo):
self._resolve_info = arg
continue
self._filter = kwargs.get("filters", {})
self._sort = kwargs.get("sort", {})
self._skip = get_value(kwargs, "skip", int)
self._take = get_value(kwargs, "take", int)
self._input = kwargs.get("input", None)
self._args = args
self._kwargs = kwargs
self._is_mutation = is_mutation
@property
def data(self):
return self._data
@property
def user(self) -> AuthUser:
return self._user
@property
def resolve_info(self) -> Optional[GraphQLResolveInfo]:
return self._resolve_info
@property
def filter(self) -> dict:
return self._filter
@property
def sort(self) -> dict:
return self._sort
@property
def skip(self) -> Optional[int]:
return self._skip
@property
def take(self) -> Optional[int]:
return self._take
@property
def input(self) -> Optional[Any]:
return self._input
@property
def args(self) -> tuple:
return self._args
@property
def kwargs(self) -> dict:
return self._kwargs
@property
def is_mutation(self) -> bool:
return self._is_mutation
def has_permission(self, permission: Enum | str) -> bool:
return permission.value if isinstance(permission, Enum) else permission in self._user_permissions

View File

@@ -1,7 +1,8 @@
from enum import Enum
from typing import Self
from cpl.graphql.schema.argument import Argument
from cpl.graphql.typing import TQuery, Resolver
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
class Field:
@@ -9,7 +10,7 @@ class Field:
def __init__(
self,
name: str,
gql_type: type = None,
t: type = None,
resolver: Resolver = None,
optional=None,
default=None,
@@ -17,7 +18,7 @@ class Field:
parent_type=None,
):
self._name = name
self._gql_type = gql_type
self._type = t
self._resolver = resolver
self._optional = optional or True
self._default = default
@@ -26,6 +27,9 @@ class Field:
self._parent_type = parent_type
self._args: dict[str, Argument] = {}
self._require_any_permission = None
self._require_any = None
self._public = False
@property
def name(self) -> str:
@@ -33,7 +37,7 @@ class Field:
@property
def type(self) -> type:
return self._gql_type
return self._type
@property
def resolver(self) -> callable:
@@ -63,6 +67,34 @@ class Field:
def arguments(self) -> dict[str, Argument]:
return self._args
@property
def require_any_permission(self) -> TRequireAnyPermissions | None:
return self._require_any_permission
@property
def require_any(self) -> TRequireAnyResolvers | None:
return self._require_any
@property
def public(self) -> bool:
return self._public
def with_type(self, t: type) -> Self:
self._type = t
return self
def with_resolver(self, resolver: Resolver) -> Self:
self._resolver = resolver
return self
def with_optional(self, optional: bool) -> Self:
self._optional = optional
return self
def with_default(self, default) -> Self:
self._default = default
return self
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
@@ -76,3 +108,21 @@ class Field:
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
return self
def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self:
assert permissions is not None, "require_any_permission cannot be None"
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
self._require_any_permission = permissions
return self
def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self:
assert permissions is not None, "permissions cannot be None"
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
assert resolvers is not None, "resolvers cannot be None"
assert all(callable(r) for r in resolvers), "All resolvers must be callable"
self._require_any = (permissions, resolvers)
return self
def with_public(self, public: bool = False) -> Self:
self._public = public
return self

View File

@@ -1,13 +1,19 @@
import asyncio
import functools
import inspect
from typing import Callable, Type, Any, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.api.middleware.request import get_request
from cpl.core.ctx import get_user
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder
@@ -202,23 +208,70 @@ class Query(StrawberryProtocol):
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
# Signatur vom Original übernehmen
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
request = get_request()
user = get_user()
# Public
if f.public:
return await self._maybe_await(resolver(*args, **kwargs))
# Auth required
if user is None:
raise graphql_error(Unauthorized("Authentication required"))
# Permissions
if f.require_any_permission:
if not any(user.has_permission(p) for p in f.require_any_permission):
raise Forbidden("Permission denied")
# Custom resolvers
if f.require_any:
perms, resolvers = f.require_any
if not any(user.has_permission(p) for p in perms):
for r in resolvers:
ok = await self._maybe_await(r(user, *args, **kwargs))
if ok:
break
else:
raise Forbidden("Permission denied")
return await self._maybe_await(resolver(*args, **kwargs))
# Signatur beibehalten
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
def _maybe_await(value):
if asyncio.iscoroutine(value):
return value
return asyncio.sleep(0, result=value) # sofort resolved Future
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.resolver:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
if f.arguments:
resolver = self._build_resolver(f)
return strawberry.field(resolver=resolver)
if not f.resolver:
return strawberry.field(resolver=lambda *_, **__: None)
return strawberry.field(resolver=f.resolver)
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(
f"Error converting field '{f.name}' to strawberry field: {e}"

View File

@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional
from graphql import GraphQLError
from cpl.api import APILogger, APIError
from cpl.api.typing import TRequest
from cpl.graphql.service.schema import Schema
class GraphQLService:
def __init__(self, logger: APILogger, schema: Schema):
self._logger = logger
if schema.schema is None:
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
self._schema = schema.schema
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute(
query,
variable_values=variables,
context_value={"request": request},
)
response_data: Dict[str, Any] = {}
if result.errors:
errors = []
for error in result.errors:
if isinstance(error, GraphQLError):
self._logger.error(f"GraphQL APIError: {error}")
errors.append({"message": error.message, "extensions": error.extensions})
continue
if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError: {error}")
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
continue
self._logger.error(f"GraphQL unexpected error: {error}")
errors.append({"message": str(error), "extensions": {"code": 500}})
response_data["errors"] = errors
if result.data:
response_data["data"] = result.data
return response_data

View File

@@ -1,7 +1,11 @@
import logging
from typing import Type, Self
import strawberry
from starlette.requests import Request
from strawberry.types import ExecutionContext
from cpl.api import APIError
from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
@@ -40,7 +44,18 @@ class Schema:
return types
def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext):
request: Request = execution_context.context.get("request")
if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError: {error}")
return {"message": error.error_message, "extensions": {"code": error.status_code}}
self._logger.error(f"GraphQL unexpected error: {error}")
return {"message": str(error), "extensions": {"code": 500}}
def build(self) -> strawberry.Schema:
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
query = self._provider.get_service(RootQuery)
if not query:
raise ValueError("RootQuery not registered in service provider")

View File

@@ -1,31 +0,0 @@
from typing import Any, Dict, Optional
from cpl.api.typing import TRequest
from cpl.graphql.service.schema import Schema
class GraphQLService:
def __init__(self, schema: Schema):
if schema.schema is None:
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
self._schema = schema.schema
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute(
query,
variable_values=variables,
context_value={"request": request},
)
response_data: Dict[str, Any] = {}
if result.errors:
response_data["errors"] = [str(e) for e in result.errors]
if result.data:
response_data["data"] = result.data
return response_data

View File

@@ -1,5 +1,15 @@
from typing import Type, Callable
from enum import Enum
from typing import Type, Callable, List, Tuple, Awaitable
from cpl.auth.permission import Permissions
from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"]
Resolver = Callable
ScalarType = str | int | float | bool | object
ScalarType = str | int | float | bool | object
TRequireAnyPermissions = List[Enum | Permissions] | None
TRequireAnyResolvers = List[
Callable[[QueryContext], bool | Awaitable[bool]],
]
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]