[WIP] with authentication #181
This commit is contained in:
@@ -85,7 +85,7 @@ def main():
|
||||
schema.query.string_field("ping", resolver=lambda: "pong")
|
||||
schema.query.with_query("hello", HelloQuery)
|
||||
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True)
|
||||
|
||||
app.with_playground()
|
||||
app.with_graphiql()
|
||||
|
||||
@@ -26,12 +26,12 @@ class AuthorGraphType(GraphType[Author]):
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
)
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"firstName",
|
||||
resolver=lambda root: root.first_name,
|
||||
)
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"lastName",
|
||||
resolver=lambda root: root.last_name,
|
||||
)
|
||||
).with_public(True)
|
||||
|
||||
@@ -29,7 +29,7 @@ class PostGraphType(GraphType[Post]):
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
)
|
||||
).with_public(True)
|
||||
|
||||
async def _a(root: Post):
|
||||
return await authors.get_by_id(root.author_id)
|
||||
@@ -38,12 +38,12 @@ class PostGraphType(GraphType[Post]):
|
||||
"author",
|
||||
AuthorGraphType,
|
||||
resolver=_a#lambda root: root.author_id,
|
||||
)
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"title",
|
||||
resolver=lambda root: root.title,
|
||||
)
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"content",
|
||||
resolver=lambda root: root.content,
|
||||
)
|
||||
).with_public(True)
|
||||
|
||||
@@ -25,6 +25,21 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
request = get_request()
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
self._logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
user = getattr(request.state, "user", None)
|
||||
if not user or user.deleted:
|
||||
self._logger.debug(f"Unauthorized access to {url}, user missing or deleted")
|
||||
return await Unauthorized("Unauthorized").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
|
||||
async def _old_call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = get_request()
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
self._logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
@@ -9,6 +9,10 @@ from starlette.types import Scope, Receive, Send
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient
|
||||
from cpl.auth.schema import AuthUser
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
@@ -17,12 +21,15 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger):
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
self._logger = logger
|
||||
|
||||
self._keycloak = keycloak
|
||||
self._user_dao = user_dao
|
||||
|
||||
self._ctx_token = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
@@ -30,6 +37,7 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
await self._try_set_user(request)
|
||||
with self._provider.create_scope():
|
||||
inject(await self._app(scope, receive, send))
|
||||
finally:
|
||||
@@ -53,6 +61,36 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_request_context.reset(self._ctx_token)
|
||||
|
||||
async def _try_set_user(self, request: Request):
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return
|
||||
|
||||
token = auth_header.split("Bearer ")[1]
|
||||
try:
|
||||
token_info = self._keycloak.introspect(token)
|
||||
if not token_info.get("active", False):
|
||||
return
|
||||
|
||||
keycloak_id = self._keycloak.get_user_id(token)
|
||||
if not keycloak_id:
|
||||
return
|
||||
|
||||
user = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||
if not user:
|
||||
user = AuthUser(0, keycloak_id)
|
||||
uid = await self._user_dao.create(user)
|
||||
user = await self._user_dao.get_by_id(uid)
|
||||
|
||||
if user.deleted:
|
||||
return
|
||||
|
||||
request.state.user = user
|
||||
set_user(user)
|
||||
self._logger.trace(f"User {user.id} bound to request {request.state.request_id}")
|
||||
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Silent user binding failed: {e}")
|
||||
|
||||
def get_request() -> Optional[TRequest]:
|
||||
return _request_context.get()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from cpl.graphql.service.service import GraphQLService
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
|
||||
|
||||
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:
|
||||
|
||||
14
src/cpl-graphql/cpl/graphql/error.py
Normal file
14
src/cpl-graphql/cpl/graphql/error.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from graphql import GraphQLError
|
||||
|
||||
from cpl.api import APIError
|
||||
|
||||
|
||||
def graphql_error(api_error: APIError) -> GraphQLError:
|
||||
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
|
||||
return GraphQLError(
|
||||
message=api_error.error_message,
|
||||
extensions={
|
||||
"code": api_error.status_code,
|
||||
},
|
||||
original_error=api_error,
|
||||
)
|
||||
@@ -8,7 +8,7 @@ from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.service.schema import Schema
|
||||
from cpl.graphql.service.service import GraphQLService
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
|
||||
|
||||
class GraphQLModule(Module):
|
||||
|
||||
90
src/cpl-graphql/cpl/graphql/query_context.py
Normal file
90
src/cpl-graphql/cpl/graphql/query_context.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
|
||||
from graphql import GraphQLResolveInfo
|
||||
|
||||
from cpl.auth.schema import AuthUser, Permission
|
||||
from cpl.core.utils import get_value
|
||||
|
||||
|
||||
class QueryContext:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Any,
|
||||
user: Optional[AuthUser],
|
||||
user_permissions: Optional[list[Enum | Permission]],
|
||||
is_mutation: bool = False,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self._data = data
|
||||
self._user = user
|
||||
if user_permissions is None:
|
||||
user_permissions = []
|
||||
self._user_permissions: list[str] = [x.name for x in user_permissions]
|
||||
|
||||
self._resolve_info = None
|
||||
for arg in args:
|
||||
if isinstance(arg, GraphQLResolveInfo):
|
||||
self._resolve_info = arg
|
||||
continue
|
||||
|
||||
self._filter = kwargs.get("filters", {})
|
||||
self._sort = kwargs.get("sort", {})
|
||||
self._skip = get_value(kwargs, "skip", int)
|
||||
self._take = get_value(kwargs, "take", int)
|
||||
|
||||
self._input = kwargs.get("input", None)
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
self._is_mutation = is_mutation
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def user(self) -> AuthUser:
|
||||
return self._user
|
||||
|
||||
@property
|
||||
def resolve_info(self) -> Optional[GraphQLResolveInfo]:
|
||||
return self._resolve_info
|
||||
|
||||
@property
|
||||
def filter(self) -> dict:
|
||||
return self._filter
|
||||
|
||||
@property
|
||||
def sort(self) -> dict:
|
||||
return self._sort
|
||||
|
||||
@property
|
||||
def skip(self) -> Optional[int]:
|
||||
return self._skip
|
||||
|
||||
@property
|
||||
def take(self) -> Optional[int]:
|
||||
return self._take
|
||||
|
||||
@property
|
||||
def input(self) -> Optional[Any]:
|
||||
return self._input
|
||||
|
||||
@property
|
||||
def args(self) -> tuple:
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def kwargs(self) -> dict:
|
||||
return self._kwargs
|
||||
|
||||
@property
|
||||
def is_mutation(self) -> bool:
|
||||
return self._is_mutation
|
||||
|
||||
def has_permission(self, permission: Enum | str) -> bool:
|
||||
return permission.value if isinstance(permission, Enum) else permission in self._user_permissions
|
||||
@@ -1,7 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from cpl.graphql.schema.argument import Argument
|
||||
from cpl.graphql.typing import TQuery, Resolver
|
||||
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
|
||||
|
||||
|
||||
class Field:
|
||||
@@ -9,7 +10,7 @@ class Field:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
gql_type: type = None,
|
||||
t: type = None,
|
||||
resolver: Resolver = None,
|
||||
optional=None,
|
||||
default=None,
|
||||
@@ -17,7 +18,7 @@ class Field:
|
||||
parent_type=None,
|
||||
):
|
||||
self._name = name
|
||||
self._gql_type = gql_type
|
||||
self._type = t
|
||||
self._resolver = resolver
|
||||
self._optional = optional or True
|
||||
self._default = default
|
||||
@@ -26,6 +27,9 @@ class Field:
|
||||
self._parent_type = parent_type
|
||||
|
||||
self._args: dict[str, Argument] = {}
|
||||
self._require_any_permission = None
|
||||
self._require_any = None
|
||||
self._public = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -33,7 +37,7 @@ class Field:
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._gql_type
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def resolver(self) -> callable:
|
||||
@@ -63,6 +67,34 @@ class Field:
|
||||
def arguments(self) -> dict[str, Argument]:
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def require_any_permission(self) -> TRequireAnyPermissions | None:
|
||||
return self._require_any_permission
|
||||
|
||||
@property
|
||||
def require_any(self) -> TRequireAnyResolvers | None:
|
||||
return self._require_any
|
||||
|
||||
@property
|
||||
def public(self) -> bool:
|
||||
return self._public
|
||||
|
||||
def with_type(self, t: type) -> Self:
|
||||
self._type = t
|
||||
return self
|
||||
|
||||
def with_resolver(self, resolver: Resolver) -> Self:
|
||||
self._resolver = resolver
|
||||
return self
|
||||
|
||||
def with_optional(self, optional: bool) -> Self:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_default(self, default) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self:
|
||||
if name in self._args:
|
||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||
@@ -76,3 +108,21 @@ class Field:
|
||||
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
|
||||
return self
|
||||
|
||||
def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self:
|
||||
assert permissions is not None, "require_any_permission cannot be None"
|
||||
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
|
||||
self._require_any_permission = permissions
|
||||
return self
|
||||
|
||||
def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self:
|
||||
assert permissions is not None, "permissions cannot be None"
|
||||
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
|
||||
assert resolvers is not None, "resolvers cannot be None"
|
||||
assert all(callable(r) for r in resolvers), "All resolvers must be callable"
|
||||
self._require_any = (permissions, resolvers)
|
||||
return self
|
||||
|
||||
def with_public(self, public: bool = False) -> Self:
|
||||
self._public = public
|
||||
return self
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Callable, Type, Any, Optional
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.core.ctx import get_user
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
@@ -202,23 +208,70 @@ class Query(StrawberryProtocol):
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||
# Signatur vom Original übernehmen
|
||||
sig = getattr(resolver, "__signature__", None)
|
||||
|
||||
@functools.wraps(resolver)
|
||||
async def _auth_resolver(*args, **kwargs):
|
||||
request = get_request()
|
||||
user = get_user()
|
||||
|
||||
# Public
|
||||
if f.public:
|
||||
return await self._maybe_await(resolver(*args, **kwargs))
|
||||
|
||||
# Auth required
|
||||
if user is None:
|
||||
raise graphql_error(Unauthorized("Authentication required"))
|
||||
|
||||
# Permissions
|
||||
if f.require_any_permission:
|
||||
if not any(user.has_permission(p) for p in f.require_any_permission):
|
||||
raise Forbidden("Permission denied")
|
||||
|
||||
# Custom resolvers
|
||||
if f.require_any:
|
||||
perms, resolvers = f.require_any
|
||||
if not any(user.has_permission(p) for p in perms):
|
||||
for r in resolvers:
|
||||
ok = await self._maybe_await(r(user, *args, **kwargs))
|
||||
if ok:
|
||||
break
|
||||
else:
|
||||
raise Forbidden("Permission denied")
|
||||
|
||||
return await self._maybe_await(resolver(*args, **kwargs))
|
||||
|
||||
# Signatur beibehalten
|
||||
if sig:
|
||||
_auth_resolver.__signature__ = sig
|
||||
|
||||
return _auth_resolver
|
||||
|
||||
@staticmethod
|
||||
def _maybe_await(value):
|
||||
if asyncio.iscoroutine(value):
|
||||
return value
|
||||
return asyncio.sleep(0, result=value) # sofort resolved Future
|
||||
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
try:
|
||||
if f.resolver:
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda *_, **__: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
resolver = f.resolver
|
||||
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
return strawberry.field(resolver=resolver)
|
||||
|
||||
if not f.resolver:
|
||||
return strawberry.field(resolver=lambda *_, **__: None)
|
||||
|
||||
return strawberry.field(resolver=f.resolver)
|
||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||
except StrawberryException as e:
|
||||
raise Exception(
|
||||
f"Error converting field '{f.name}' to strawberry field: {e}"
|
||||
|
||||
51
src/cpl-graphql/cpl/graphql/service/graphql.py
Normal file
51
src/cpl-graphql/cpl/graphql/service/graphql.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from graphql import GraphQLError
|
||||
|
||||
from cpl.api import APILogger, APIError
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLService:
|
||||
def __init__(self, logger: APILogger, schema: Schema):
|
||||
self._logger = logger
|
||||
|
||||
if schema.schema is None:
|
||||
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
|
||||
self._schema = schema.schema
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: str,
|
||||
variables: Optional[Dict[str, Any]],
|
||||
request: TRequest,
|
||||
) -> Dict[str, Any]:
|
||||
result = await self._schema.execute(
|
||||
query,
|
||||
variable_values=variables,
|
||||
context_value={"request": request},
|
||||
)
|
||||
|
||||
response_data: Dict[str, Any] = {}
|
||||
if result.errors:
|
||||
errors = []
|
||||
for error in result.errors:
|
||||
if isinstance(error, GraphQLError):
|
||||
self._logger.error(f"GraphQL APIError: {error}")
|
||||
errors.append({"message": error.message, "extensions": error.extensions})
|
||||
continue
|
||||
|
||||
if isinstance(error, APIError):
|
||||
self._logger.error(f"GraphQL APIError: {error}")
|
||||
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
|
||||
continue
|
||||
|
||||
self._logger.error(f"GraphQL unexpected error: {error}")
|
||||
errors.append({"message": str(error), "extensions": {"code": 500}})
|
||||
|
||||
response_data["errors"] = errors
|
||||
if result.data:
|
||||
response_data["data"] = result.data
|
||||
|
||||
return response_data
|
||||
@@ -1,7 +1,11 @@
|
||||
import logging
|
||||
from typing import Type, Self
|
||||
|
||||
import strawberry
|
||||
from starlette.requests import Request
|
||||
from strawberry.types import ExecutionContext
|
||||
|
||||
from cpl.api import APIError
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
@@ -40,7 +44,18 @@ class Schema:
|
||||
|
||||
return types
|
||||
|
||||
def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext):
|
||||
request: Request = execution_context.context.get("request")
|
||||
|
||||
if isinstance(error, APIError):
|
||||
self._logger.error(f"GraphQL APIError: {error}")
|
||||
return {"message": error.error_message, "extensions": {"code": error.status_code}}
|
||||
|
||||
self._logger.error(f"GraphQL unexpected error: {error}")
|
||||
return {"message": str(error), "extensions": {"code": 500}}
|
||||
|
||||
def build(self) -> strawberry.Schema:
|
||||
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLService:
|
||||
def __init__(self, schema: Schema):
|
||||
if schema.schema is None:
|
||||
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
|
||||
self._schema = schema.schema
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: str,
|
||||
variables: Optional[Dict[str, Any]],
|
||||
request: TRequest,
|
||||
) -> Dict[str, Any]:
|
||||
result = await self._schema.execute(
|
||||
query,
|
||||
variable_values=variables,
|
||||
context_value={"request": request},
|
||||
)
|
||||
|
||||
response_data: Dict[str, Any] = {}
|
||||
if result.errors:
|
||||
response_data["errors"] = [str(e) for e in result.errors]
|
||||
if result.data:
|
||||
response_data["data"] = result.data
|
||||
|
||||
return response_data
|
||||
@@ -1,5 +1,15 @@
|
||||
from typing import Type, Callable
|
||||
from enum import Enum
|
||||
from typing import Type, Callable, List, Tuple, Awaitable
|
||||
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
|
||||
TQuery = Type["Query"]
|
||||
Resolver = Callable
|
||||
ScalarType = str | int | float | bool | object
|
||||
|
||||
TRequireAnyPermissions = List[Enum | Permissions] | None
|
||||
TRequireAnyResolvers = List[
|
||||
Callable[[QueryContext], bool | Awaitable[bool]],
|
||||
]
|
||||
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]
|
||||
|
||||
Reference in New Issue
Block a user