From 21c01164afdf7ca9c78a59906fa92b6d3ad15c21 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 14:53:57 +0200 Subject: [PATCH] require any #181 --- example/api/src/main.py | 8 ++- example/api/src/model/post_query.py | 12 ++-- example/api/src/permissions.py | 8 +++ .../cpl/auth/permission/permission_module.py | 3 +- .../cpl/auth/permission/permission_seeder.py | 1 - .../cpl/auth/permission/role_seeder.py | 60 +++++++++++++++++++ .../schema/_administration/auth_user_dao.py | 12 ++-- .../cpl/auth/schema/_permission/role.py | 2 +- .../schema/_permission/role_permission.py | 2 +- .../cpl/auth/schema/_permission/role_user.py | 2 +- .../scripts/mysql/3-roles-permissions.sql | 38 ++++++------ .../scripts/postgres/3-roles-permissions.sql | 8 +-- .../database/abc/data_access_object_abc.py | 2 +- .../cpl/database/table_manager.py | 2 +- src/cpl-graphql/cpl/graphql/query_context.py | 41 +------------ src/cpl-graphql/cpl/graphql/schema/field.py | 11 +++- src/cpl-graphql/cpl/graphql/schema/query.py | 60 ++++++++----------- .../cpl/graphql/service/graphql.py | 15 ++--- src/cpl-graphql/cpl/graphql/service/schema.py | 13 ---- 19 files changed, 165 insertions(+), 135 deletions(-) create mode 100644 example/api/src/permissions.py create mode 100644 src/cpl-auth/cpl/auth/permission/role_seeder.py diff --git a/example/api/src/main.py b/example/api/src/main.py index c149fe47..06e39aa4 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -18,6 +18,7 @@ from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao from model.post_query import PostFilter, PostSort, PostGraphType +from permissions import PostPermissions from queries.hello import HelloQuery from scoped_service import ScopedService from service import PingService @@ -85,11 +86,16 @@ def main(): schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) - schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True) + ( + schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) + .with_require_any_permission(PostPermissions.read) + ) app.with_playground() app.with_graphiql() + app.with_permissions(PostPermissions) + provider = builder.service_provider user_cache = provider.get_service(Cache[AuthUser]) role_cache = provider.get_service(Cache[Role]) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 48845617..381c94ca 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,3 +1,4 @@ +from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort @@ -34,11 +35,12 @@ class PostGraphType(GraphType[Post]): async def _a(root: Post): return await authors.get_by_id(root.author_id) - self.object_field( - "author", - AuthorGraphType, - resolver=_a#lambda root: root.author_id, - ).with_public(True) + def r_name(ctx: QueryContext): + return ctx.user.username == "admin" + + self.object_field("author", AuthorGraphType, resolver=_a).with_require_any( + [], [r_name] + ) self.string_field( "title", resolver=lambda root: root.title, diff --git a/example/api/src/permissions.py b/example/api/src/permissions.py new file mode 100644 index 00000000..d2e1d450 --- /dev/null +++ b/example/api/src/permissions.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class PostPermissions(Enum): + + read = "post.read" + write = "post.write" + delete = "post.delete" \ No newline at end of file diff --git a/src/cpl-auth/cpl/auth/permission/permission_module.py b/src/cpl-auth/cpl/auth/permission/permission_module.py index 16955c57..eafaeadc 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_module.py +++ b/src/cpl-auth/cpl/auth/permission/permission_module.py @@ -2,6 +2,7 @@ from cpl.auth.auth_module import AuthModule from cpl.auth.permission.permission_seeder import PermissionSeeder from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry +from cpl.auth.permission.role_seeder import RoleSeeder from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.database_module import DatabaseModule from cpl.dependency.module.module import Module @@ -10,7 +11,7 @@ from cpl.dependency.service_collection import ServiceCollection class PermissionsModule(Module): dependencies = [DatabaseModule, AuthModule] - singleton = [(DataSeederABC, PermissionSeeder)] + transient = [(DataSeederABC, PermissionSeeder), (DataSeederABC, RoleSeeder)] @staticmethod def register(collection: ServiceCollection): diff --git a/src/cpl-auth/cpl/auth/permission/permission_seeder.py b/src/cpl-auth/cpl/auth/permission/permission_seeder.py index d9d42cfa..aab41139 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_seeder.py +++ b/src/cpl-auth/cpl/auth/permission/permission_seeder.py @@ -1,4 +1,3 @@ -from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.schema import ( Permission, diff --git a/src/cpl-auth/cpl/auth/permission/role_seeder.py b/src/cpl-auth/cpl/auth/permission/role_seeder.py new file mode 100644 index 00000000..2c7687bd --- /dev/null +++ b/src/cpl-auth/cpl/auth/permission/role_seeder.py @@ -0,0 +1,60 @@ +from cpl.auth.schema import ( + Role, + RolePermission, + PermissionDao, + RoleDao, + RolePermissionDao, + ApiKeyDao, + ApiKeyPermissionDao, + AuthUserDao, + RoleUserDao, + RoleUser, +) +from cpl.database.abc.data_seeder_abc import DataSeederABC +from cpl.database.logger import DBLogger + + +class RoleSeeder(DataSeederABC): + def __init__( + self, + logger: DBLogger, + permission_dao: PermissionDao, + role_dao: RoleDao, + role_permission_dao: RolePermissionDao, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + user_dao: AuthUserDao, + role_user_dao: RoleUserDao, + ): + DataSeederABC.__init__(self) + self._logger = logger + self._permission_dao = permission_dao + self._role_dao = role_dao + self._role_permission_dao = role_permission_dao + self._api_key_dao = api_key_dao + self._api_key_permission_dao = api_key_permission_dao + self._user_dao = user_dao + self._role_user_dao = role_user_dao + + async def seed(self): + self._logger.info("Creating admin role") + roles = await self._role_dao.get_all() + if len(roles) == 0: + rid = await self._role_dao.create(Role(0, "admin", "Default admin role")) + permissions = await self._permission_dao.get_all() + + await self._role_permission_dao.create_many( + [RolePermission(0, rid, permission.id) for permission in permissions] + ) + + role = await self._role_dao.get_by_name("admin") + if len(await role.users) > 0: + return + + users = await self._user_dao.get_all() + if len(users) == 0: + return + + user = users[0] + self._logger.warning(f"Assigning admin role to first user {user.id}") + await self._role_user_dao.create(RoleUser(0, role.id, user.id)) diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py index 4b27549a..bf59a534 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py @@ -1,6 +1,8 @@ from typing import Optional, Union from cpl.auth.permission.permissions import Permissions +from cpl.auth.schema._permission.permission_dao import PermissionDao +from cpl.auth.schema._permission.permission import Permission from cpl.auth.schema._administration.auth_user import AuthUser from cpl.database import TableManager from cpl.database.abc import DbModelDaoABC @@ -10,10 +12,12 @@ from cpl.dependency.context import get_provider class AuthUserDao(DbModelDaoABC[AuthUser]): - def __init__(self): + def __init__(self, permission_dao: PermissionDao): DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) - self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") + self._permissions = permission_dao + + self.attribute(AuthUser.keycloak_id, str) async def get_users(): return [(x.id, x.username, x.email) for x in await self.get_all()] @@ -54,7 +58,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): return result[0]["count"] > 0 - async def get_permissions(self, user_id: int) -> list[Permissions]: + async def get_permissions(self, user_id: int) -> list[Permission]: result = await self._db.select_map( f""" SELECT p.* @@ -66,4 +70,4 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): AND ru.deleted = FALSE; """ ) - return [Permissions(p["name"]) for p in result] + return [self._permissions.to_object(x) for x in result] diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role.py b/src/cpl-auth/cpl/auth/schema/_permission/role.py index 24a5d82d..3c1b0a1f 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role.py @@ -6,7 +6,7 @@ from async_property import async_property from cpl.auth.permission.permissions import Permissions from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class Role(DbModelABC[Self]): diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py index 82bacb4a..c58d8682 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py @@ -5,7 +5,7 @@ from async_property import async_property from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class RolePermission(DbModelABC[Self]): diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py index 5db0f892..72504768 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py @@ -5,7 +5,7 @@ from async_property import async_property from cpl.core.typing import SerialId from cpl.database.abc import DbJoinModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class RoleUser(DbJoinModelABC): diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql index f3082a48..63a58fbf 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql @@ -89,14 +89,14 @@ END; CREATE TABLE IF NOT EXISTS permission_role_permissions ( id INT AUTO_INCREMENT PRIMARY KEY, - RoleId INT NOT NULL, + roleId INT NOT NULL, permissionId INT NOT NULL, deleted BOOL NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId), - CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId), + CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) ); @@ -104,7 +104,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions CREATE TABLE IF NOT EXISTS permission_role_permissions_history ( id INT NOT NULL, - RoleId INT NOT NULL, + roleId INT NOT NULL, permissionId INT NOT NULL, deleted BOOL NOT NULL, editorId INT NULL, @@ -118,8 +118,8 @@ CREATE TRIGGER TR_RolePermissionsUpdate FOR EACH ROW BEGIN INSERT INTO permission_role_permissions_history - (id, RoleId, permissionId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); + (id, roleId, permissionId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; CREATE TRIGGER TR_RolePermissionsDelete @@ -128,30 +128,30 @@ CREATE TRIGGER TR_RolePermissionsDelete FOR EACH ROW BEGIN INSERT INTO permission_role_permissions_history - (id, RoleId, permissionId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); + (id, roleId, permissionId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); END; CREATE TABLE IF NOT EXISTS permission_role_auth_users ( id INT AUTO_INCREMENT PRIMARY KEY, - RoleId INT NOT NULL, - UserId INT NOT NULL, + roleId INT NOT NULL, + userId INT NOT NULL, deleted BOOL NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId), - CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, - CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (UserId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, + CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId), + CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (userId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) ); CREATE TABLE IF NOT EXISTS permission_role_auth_users_history ( id INT NOT NULL, - RoleId INT NOT NULL, - UserId INT NOT NULL, + roleId INT NOT NULL, + userId INT NOT NULL, deleted BOOL NOT NULL, editorId INT NULL, created TIMESTAMP NOT NULL, @@ -164,8 +164,8 @@ CREATE TRIGGER TR_Roleauth_usersUpdate FOR EACH ROW BEGIN INSERT INTO permission_role_auth_users_history - (id, RoleId, UserId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.UserId, OLD.deleted, OLD.editorId, OLD.created, NOW()); + (id, roleId, userId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; CREATE TRIGGER TR_Roleauth_usersDelete @@ -174,6 +174,6 @@ CREATE TRIGGER TR_Roleauth_usersDelete FOR EACH ROW BEGIN INSERT INTO permission_role_auth_users_history - (id, RoleId, UserId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.UserId, 1, OLD.editorId, OLD.created, NOW()); + (id, roleId, userId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW()); END; diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql index 42b9283b..72400191 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql @@ -79,7 +79,7 @@ CREATE TRIGGER versioning_trigger EXECUTE PROCEDURE public.history_trigger_function(); -- Role user -CREATE TABLE permission.role_users +CREATE TABLE permission.role_auth_users ( id SERIAL PRIMARY KEY, RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE, @@ -93,13 +93,13 @@ CREATE TABLE permission.role_users CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId) ); -CREATE TABLE permission.role_users_history +CREATE TABLE permission.role_auth_users_history ( - LIKE permission.role_users + LIKE permission.role_auth_users ); CREATE TRIGGER versioning_trigger BEFORE INSERT OR UPDATE OR DELETE - ON permission.role_users + ON permission.role_auth_users FOR EACH ROW EXECUTE PROCEDURE public.history_trigger_function(); \ No newline at end of file diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py index 44f2a0bf..7f1e235b 100644 --- a/src/cpl-database/cpl/database/abc/data_access_object_abc.py +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -85,7 +85,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): self.__ignored_attributes.add(attr_name) if not db_name: - db_name = attr_name.lower().replace("_", "") + db_name = String.to_camel_case(attr_name) self.__db_names[attr_name] = db_name self.__db_names[db_name] = db_name diff --git a/src/cpl-database/cpl/database/table_manager.py b/src/cpl-database/cpl/database/table_manager.py index 9bd1f6b2..2d5ac533 100644 --- a/src/cpl-database/cpl/database/table_manager.py +++ b/src/cpl-database/cpl/database/table_manager.py @@ -32,7 +32,7 @@ class TableManager: ServerTypes.MYSQL: "permission_role_permissions", }, "role_users": { - ServerTypes.POSTGRES: "permission.role_users", + ServerTypes.POSTGRES: "permission.role_auth_users", ServerTypes.MYSQL: "permission_role_auth_users", }, } diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py index 79d0b965..9b75d694 100644 --- a/src/cpl-graphql/cpl/graphql/query_context.py +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -4,6 +4,7 @@ from typing import Optional, Any from graphql import GraphQLResolveInfo from cpl.auth.schema import AuthUser, Permission +from cpl.core.ctx import get_user from cpl.core.utils import get_value @@ -11,19 +12,13 @@ class QueryContext: def __init__( self, - data: Any, - user: Optional[AuthUser], user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs ): - - self._data = data - self._user = user - if user_permissions is None: - user_permissions = [] - self._user_permissions: list[str] = [x.name for x in user_permissions] + self._user = get_user() + self._user_permissions = user_permissions or [] self._resolve_info = None for arg in args: @@ -31,21 +26,11 @@ class QueryContext: self._resolve_info = arg continue - self._filter = kwargs.get("filters", {}) - self._sort = kwargs.get("sort", {}) - self._skip = get_value(kwargs, "skip", int) - self._take = get_value(kwargs, "take", int) - - self._input = kwargs.get("input", None) self._args = args self._kwargs = kwargs self._is_mutation = is_mutation - @property - def data(self): - return self._data - @property def user(self) -> AuthUser: return self._user @@ -54,26 +39,6 @@ class QueryContext: def resolve_info(self) -> Optional[GraphQLResolveInfo]: return self._resolve_info - @property - def filter(self) -> dict: - return self._filter - - @property - def sort(self) -> dict: - return self._sort - - @property - def skip(self) -> Optional[int]: - return self._skip - - @property - def take(self) -> Optional[int]: - return self._take - - @property - def input(self) -> Optional[Any]: - return self._input - @property def args(self) -> tuple: return self._args diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index d9417bdb..421413a4 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -5,7 +5,7 @@ from cpl.graphql.schema.argument import Argument from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers -class Field: +class Field: def __init__( self, @@ -109,9 +109,12 @@ class Field: self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) return self - def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self: + def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self: + if not isinstance(permissions, list): + permissions = list(permissions) + assert permissions is not None, "require_any_permission cannot be None" - assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + assert all(isinstance(x, (str, Enum)) for x in permissions), "All permissions must be of Permission type" self._require_any_permission = permissions return self @@ -124,5 +127,7 @@ class Field: return self def with_public(self, public: bool = False) -> Self: + assert self._require_any is None, "Field cannot be public and have require_any set" + assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set" self._public = public return self diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 2f0e23f0..0b8df16f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,19 +1,19 @@ -import asyncio import functools import inspect +from asyncio import iscoroutinefunction from typing import Callable, Type, Any, Optional import strawberry from strawberry.exceptions import StrawberryException from cpl.api import Unauthorized, Forbidden -from cpl.api.middleware.request import get_request from cpl.core.ctx import get_user from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.error import graphql_error +from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder @@ -141,7 +141,6 @@ class Query(StrawberryProtocol): if v is None: continue - # verschachtelte Inputs rekursiv if hasattr(v, "__dict__"): result[k] = input_to_dict(v) else: @@ -152,9 +151,6 @@ class Query(StrawberryProtocol): filter_dict = input_to_dict(filter) if filter is not None else None sort_dict = None - if filter is not None: - pass - if sort is not None: sort_dict = {} for k, v in sort.__dict__.items(): @@ -202,59 +198,55 @@ class Query(StrawberryProtocol): sig = inspect.Signature(parameters=params, return_annotation=f.type) - def _resolver(*args, **kwargs): - return f.resolver(*args, **kwargs) if f.resolver else None + async def _resolver(*args, **kwargs): + if f.resolver is None: + return None + + if iscoroutinefunction(f.resolver): + return await f.resolver(*args, **kwargs) + return f.resolver(*args, **kwargs) _resolver.__signature__ = sig return _resolver def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: - # Signatur vom Original übernehmen sig = getattr(resolver, "__signature__", None) @functools.wraps(resolver) async def _auth_resolver(*args, **kwargs): - request = get_request() + if f.public: + return await self._run_resolver(resolver, *args, **kwargs) + user = get_user() - # Public - if f.public: - return await self._maybe_await(resolver(*args, **kwargs)) - - # Auth required if user is None: - raise graphql_error(Unauthorized("Authentication required")) + raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) - # Permissions if f.require_any_permission: - if not any(user.has_permission(p) for p in f.require_any_permission): - raise Forbidden("Permission denied") + if not any([await user.has_permission(p) for p in f.require_any_permission]): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) - # Custom resolvers if f.require_any: perms, resolvers = f.require_any - if not any(user.has_permission(p) for p in perms): - for r in resolvers: - ok = await self._maybe_await(r(user, *args, **kwargs)) - if ok: - break - else: - raise Forbidden("Permission denied") + if not any([await user.has_permission(p) for p in perms]): + ctx = QueryContext([x.name for x in await user.permissions]) + resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] - return await self._maybe_await(resolver(*args, **kwargs)) + if not any(resolved): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + return await self._run_resolver(resolver, *args, **kwargs) - # Signatur beibehalten if sig: _auth_resolver.__signature__ = sig return _auth_resolver @staticmethod - def _maybe_await(value): - if asyncio.iscoroutine(value): - return value - return asyncio.sleep(0, result=value) # sofort resolved Future - + async def _run_resolver(r: Callable, *args, **kwargs): + if iscoroutinefunction(r): + return await r(*args, **kwargs) + return r(*args, **kwargs) def _field_to_strawberry(self, f: Field) -> Any: resolver = None diff --git a/src/cpl-graphql/cpl/graphql/service/graphql.py b/src/cpl-graphql/cpl/graphql/service/graphql.py index c816b2e1..cb4ee667 100644 --- a/src/cpl-graphql/cpl/graphql/service/graphql.py +++ b/src/cpl-graphql/cpl/graphql/service/graphql.py @@ -31,17 +31,18 @@ class GraphQLService: if result.errors: errors = [] for error in result.errors: - if isinstance(error, GraphQLError): - self._logger.error(f"GraphQL APIError: {error}") - errors.append({"message": error.message, "extensions": error.extensions}) - continue - if isinstance(error, APIError): - self._logger.error(f"GraphQL APIError: {error}") + self._logger.error(f"GraphQL APIError", error) errors.append({"message": error.error_message, "extensions": {"code": error.status_code}}) continue - self._logger.error(f"GraphQL unexpected error: {error}") + if isinstance(error, GraphQLError): + + self._logger.error(f"GraphQLError", error) + errors.append({"message": error.message, "extensions": error.extensions}) + continue + + self._logger.error(f"GraphQL unexpected error", error) errors.append({"message": str(error), "extensions": {"code": 500}}) response_data["errors"] = errors diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index f0b01b05..c1c43cdc 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -2,10 +2,7 @@ import logging from typing import Type, Self import strawberry -from starlette.requests import Request -from strawberry.types import ExecutionContext -from cpl.api import APIError from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -44,16 +41,6 @@ class Schema: return types - def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext): - request: Request = execution_context.context.get("request") - - if isinstance(error, APIError): - self._logger.error(f"GraphQL APIError: {error}") - return {"message": error.error_message, "extensions": {"code": error.status_code}} - - self._logger.error(f"GraphQL unexpected error: {error}") - return {"message": str(error), "extensions": {"code": 500}} - def build(self) -> strawberry.Schema: logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) query = self._provider.get_service(RootQuery)