WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
19 changed files with 165 additions and 135 deletions
Showing only changes of commit 3286a95cbf - Show all commits

View File

@@ -18,6 +18,7 @@ from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
from model.post_dao import PostDao from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType from model.post_query import PostFilter, PostSort, PostGraphType
from permissions import PostPermissions
from queries.hello import HelloQuery from queries.hello import HelloQuery
from scoped_service import ScopedService from scoped_service import ScopedService
from service import PingService from service import PingService
@@ -85,11 +86,16 @@ def main():
schema.query.string_field("ping", resolver=lambda: "pong") schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery) schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True) (
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
.with_require_any_permission(PostPermissions.read)
)
app.with_playground() app.with_playground()
app.with_graphiql() app.with_graphiql()
app.with_permissions(PostPermissions)
provider = builder.service_provider provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser]) user_cache = provider.get_service(Cache[AuthUser])
role_cache = provider.get_service(Cache[Role]) role_cache = provider.get_service(Cache[Role])

View File

@@ -1,3 +1,4 @@
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort import Sort
@@ -34,11 +35,12 @@ class PostGraphType(GraphType[Post]):
async def _a(root: Post): async def _a(root: Post):
return await authors.get_by_id(root.author_id) return await authors.get_by_id(root.author_id)
self.object_field( def r_name(ctx: QueryContext):
"author", return ctx.user.username == "admin"
AuthorGraphType,
resolver=_a#lambda root: root.author_id, self.object_field("author", AuthorGraphType, resolver=_a).with_require_any(
).with_public(True) [], [r_name]
)
self.string_field( self.string_field(
"title", "title",
resolver=lambda root: root.title, resolver=lambda root: root.title,

View File

@@ -0,0 +1,8 @@
from enum import Enum
class PostPermissions(Enum):
read = "post.read"
write = "post.write"
delete = "post.delete"

View File

@@ -2,6 +2,7 @@ from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_seeder import PermissionSeeder from cpl.auth.permission.permission_seeder import PermissionSeeder
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.auth.permission.role_seeder import RoleSeeder
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.database_module import DatabaseModule from cpl.database.database_module import DatabaseModule
from cpl.dependency.module.module import Module from cpl.dependency.module.module import Module
@@ -10,7 +11,7 @@ from cpl.dependency.service_collection import ServiceCollection
class PermissionsModule(Module): class PermissionsModule(Module):
dependencies = [DatabaseModule, AuthModule] dependencies = [DatabaseModule, AuthModule]
singleton = [(DataSeederABC, PermissionSeeder)] transient = [(DataSeederABC, PermissionSeeder), (DataSeederABC, RoleSeeder)]
@staticmethod @staticmethod
def register(collection: ServiceCollection): def register(collection: ServiceCollection):

View File

@@ -1,4 +1,3 @@
from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.auth.schema import ( from cpl.auth.schema import (
Permission, Permission,

View File

@@ -0,0 +1,60 @@
from cpl.auth.schema import (
Role,
RolePermission,
PermissionDao,
RoleDao,
RolePermissionDao,
ApiKeyDao,
ApiKeyPermissionDao,
AuthUserDao,
RoleUserDao,
RoleUser,
)
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
class RoleSeeder(DataSeederABC):
def __init__(
self,
logger: DBLogger,
permission_dao: PermissionDao,
role_dao: RoleDao,
role_permission_dao: RolePermissionDao,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
user_dao: AuthUserDao,
role_user_dao: RoleUserDao,
):
DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao
self._role_dao = role_dao
self._role_permission_dao = role_permission_dao
self._api_key_dao = api_key_dao
self._api_key_permission_dao = api_key_permission_dao
self._user_dao = user_dao
self._role_user_dao = role_user_dao
async def seed(self):
self._logger.info("Creating admin role")
roles = await self._role_dao.get_all()
if len(roles) == 0:
rid = await self._role_dao.create(Role(0, "admin", "Default admin role"))
permissions = await self._permission_dao.get_all()
await self._role_permission_dao.create_many(
[RolePermission(0, rid, permission.id) for permission in permissions]
)
role = await self._role_dao.get_by_name("admin")
if len(await role.users) > 0:
return
users = await self._user_dao.get_all()
if len(users) == 0:
return
user = users[0]
self._logger.warning(f"Assigning admin role to first user {user.id}")
await self._role_user_dao.create(RoleUser(0, role.id, user.id))

View File

@@ -1,6 +1,8 @@
from typing import Optional, Union from typing import Optional, Union
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.permission import Permission
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
@@ -10,10 +12,12 @@ from cpl.dependency.context import get_provider
class AuthUserDao(DbModelDaoABC[AuthUser]): class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self): def __init__(self, permission_dao: PermissionDao):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") self._permissions = permission_dao
self.attribute(AuthUser.keycloak_id, str)
async def get_users(): async def get_users():
return [(x.id, x.username, x.email) for x in await self.get_all()] return [(x.id, x.username, x.email) for x in await self.get_all()]
@@ -54,7 +58,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
return result[0]["count"] > 0 return result[0]["count"] > 0
async def get_permissions(self, user_id: int) -> list[Permissions]: async def get_permissions(self, user_id: int) -> list[Permission]:
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""
SELECT p.* SELECT p.*
@@ -66,4 +70,4 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
AND ru.deleted = FALSE; AND ru.deleted = FALSE;
""" """
) )
return [Permissions(p["name"]) for p in result] return [self._permissions.to_object(x) for x in result]

View File

@@ -6,7 +6,7 @@ from async_property import async_property
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider, get_provider
class Role(DbModelABC[Self]): class Role(DbModelABC[Self]):

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider, get_provider
class RolePermission(DbModelABC[Self]): class RolePermission(DbModelABC[Self]):

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider, get_provider
class RoleUser(DbJoinModelABC): class RoleUser(DbJoinModelABC):

View File

@@ -89,14 +89,14 @@ END;
CREATE TABLE IF NOT EXISTS permission_role_permissions CREATE TABLE IF NOT EXISTS permission_role_permissions
( (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
RoleId INT NOT NULL, roleId INT NOT NULL,
permissionId INT NOT NULL, permissionId INT NOT NULL,
deleted BOOL NOT NULL DEFAULT FALSE, deleted BOOL NOT NULL DEFAULT FALSE,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId), CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId),
CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE,
CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
); );
@@ -104,7 +104,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
CREATE TABLE IF NOT EXISTS permission_role_permissions_history CREATE TABLE IF NOT EXISTS permission_role_permissions_history
( (
id INT NOT NULL, id INT NOT NULL,
RoleId INT NOT NULL, roleId INT NOT NULL,
permissionId INT NOT NULL, permissionId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
editorId INT NULL, editorId INT NULL,
@@ -118,8 +118,8 @@ CREATE TRIGGER TR_RolePermissionsUpdate
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_permissions_history INSERT INTO permission_role_permissions_history
(id, RoleId, permissionId, deleted, editorId, created, updated) (id, roleId, permissionId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TRIGGER TR_RolePermissionsDelete CREATE TRIGGER TR_RolePermissionsDelete
@@ -128,30 +128,30 @@ CREATE TRIGGER TR_RolePermissionsDelete
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_permissions_history INSERT INTO permission_role_permissions_history
(id, RoleId, permissionId, deleted, editorId, created, updated) (id, roleId, permissionId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TABLE IF NOT EXISTS permission_role_auth_users CREATE TABLE IF NOT EXISTS permission_role_auth_users
( (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
RoleId INT NOT NULL, roleId INT NOT NULL,
UserId INT NOT NULL, userId INT NOT NULL,
deleted BOOL NOT NULL DEFAULT FALSE, deleted BOOL NOT NULL DEFAULT FALSE,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId), CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId),
CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (UserId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (userId) REFERENCES administration_auth_users (id) ON DELETE CASCADE,
CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
); );
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
( (
id INT NOT NULL, id INT NOT NULL,
RoleId INT NOT NULL, roleId INT NOT NULL,
UserId INT NOT NULL, userId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL, created TIMESTAMP NOT NULL,
@@ -164,8 +164,8 @@ CREATE TRIGGER TR_Roleauth_usersUpdate
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_auth_users_history INSERT INTO permission_role_auth_users_history
(id, RoleId, UserId, deleted, editorId, created, updated) (id, roleId, userId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.UserId, OLD.deleted, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TRIGGER TR_Roleauth_usersDelete CREATE TRIGGER TR_Roleauth_usersDelete
@@ -174,6 +174,6 @@ CREATE TRIGGER TR_Roleauth_usersDelete
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_auth_users_history INSERT INTO permission_role_auth_users_history
(id, RoleId, UserId, deleted, editorId, created, updated) (id, roleId, userId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.UserId, 1, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW());
END; END;

View File

@@ -79,7 +79,7 @@ CREATE TRIGGER versioning_trigger
EXECUTE PROCEDURE public.history_trigger_function(); EXECUTE PROCEDURE public.history_trigger_function();
-- Role user -- Role user
CREATE TABLE permission.role_users CREATE TABLE permission.role_auth_users
( (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE, RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE,
@@ -93,13 +93,13 @@ CREATE TABLE permission.role_users
CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId) CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId)
); );
CREATE TABLE permission.role_users_history CREATE TABLE permission.role_auth_users_history
( (
LIKE permission.role_users LIKE permission.role_auth_users
); );
CREATE TRIGGER versioning_trigger CREATE TRIGGER versioning_trigger
BEFORE INSERT OR UPDATE OR DELETE BEFORE INSERT OR UPDATE OR DELETE
ON permission.role_users ON permission.role_auth_users
FOR EACH ROW FOR EACH ROW
EXECUTE PROCEDURE public.history_trigger_function(); EXECUTE PROCEDURE public.history_trigger_function();

View File

@@ -85,7 +85,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
self.__ignored_attributes.add(attr_name) self.__ignored_attributes.add(attr_name)
if not db_name: if not db_name:
db_name = attr_name.lower().replace("_", "") db_name = String.to_camel_case(attr_name)
self.__db_names[attr_name] = db_name self.__db_names[attr_name] = db_name
self.__db_names[db_name] = db_name self.__db_names[db_name] = db_name

View File

@@ -32,7 +32,7 @@ class TableManager:
ServerTypes.MYSQL: "permission_role_permissions", ServerTypes.MYSQL: "permission_role_permissions",
}, },
"role_users": { "role_users": {
ServerTypes.POSTGRES: "permission.role_users", ServerTypes.POSTGRES: "permission.role_auth_users",
ServerTypes.MYSQL: "permission_role_auth_users", ServerTypes.MYSQL: "permission_role_auth_users",
}, },
} }

View File

@@ -4,6 +4,7 @@ from typing import Optional, Any
from graphql import GraphQLResolveInfo from graphql import GraphQLResolveInfo
from cpl.auth.schema import AuthUser, Permission from cpl.auth.schema import AuthUser, Permission
from cpl.core.ctx import get_user
from cpl.core.utils import get_value from cpl.core.utils import get_value
@@ -11,19 +12,13 @@ class QueryContext:
def __init__( def __init__(
self, self,
data: Any,
user: Optional[AuthUser],
user_permissions: Optional[list[Enum | Permission]], user_permissions: Optional[list[Enum | Permission]],
is_mutation: bool = False, is_mutation: bool = False,
*args, *args,
**kwargs **kwargs
): ):
self._user = get_user()
self._data = data self._user_permissions = user_permissions or []
self._user = user
if user_permissions is None:
user_permissions = []
self._user_permissions: list[str] = [x.name for x in user_permissions]
self._resolve_info = None self._resolve_info = None
for arg in args: for arg in args:
@@ -31,21 +26,11 @@ class QueryContext:
self._resolve_info = arg self._resolve_info = arg
continue continue
self._filter = kwargs.get("filters", {})
self._sort = kwargs.get("sort", {})
self._skip = get_value(kwargs, "skip", int)
self._take = get_value(kwargs, "take", int)
self._input = kwargs.get("input", None)
self._args = args self._args = args
self._kwargs = kwargs self._kwargs = kwargs
self._is_mutation = is_mutation self._is_mutation = is_mutation
@property
def data(self):
return self._data
@property @property
def user(self) -> AuthUser: def user(self) -> AuthUser:
return self._user return self._user
@@ -54,26 +39,6 @@ class QueryContext:
def resolve_info(self) -> Optional[GraphQLResolveInfo]: def resolve_info(self) -> Optional[GraphQLResolveInfo]:
return self._resolve_info return self._resolve_info
@property
def filter(self) -> dict:
return self._filter
@property
def sort(self) -> dict:
return self._sort
@property
def skip(self) -> Optional[int]:
return self._skip
@property
def take(self) -> Optional[int]:
return self._take
@property
def input(self) -> Optional[Any]:
return self._input
@property @property
def args(self) -> tuple: def args(self) -> tuple:
return self._args return self._args

View File

@@ -109,9 +109,12 @@ class Field:
self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional)
return self return self
def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self: def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
if not isinstance(permissions, list):
permissions = list(permissions)
assert permissions is not None, "require_any_permission cannot be None" assert permissions is not None, "require_any_permission cannot be None"
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" assert all(isinstance(x, (str, Enum)) for x in permissions), "All permissions must be of Permission type"
self._require_any_permission = permissions self._require_any_permission = permissions
return self return self
@@ -124,5 +127,7 @@ class Field:
return self return self
def with_public(self, public: bool = False) -> Self: def with_public(self, public: bool = False) -> Self:
assert self._require_any is None, "Field cannot be public and have require_any set"
assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set"
self._public = public self._public = public
return self return self

View File

@@ -1,19 +1,19 @@
import asyncio
import functools import functools
import inspect import inspect
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional from typing import Callable, Type, Any, Optional
import strawberry import strawberry
from strawberry.exceptions import StrawberryException from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden from cpl.api import Unauthorized, Forbidden
from cpl.api.middleware.request import get_request
from cpl.core.ctx import get_user from cpl.core.ctx import get_user
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
@@ -141,7 +141,6 @@ class Query(StrawberryProtocol):
if v is None: if v is None:
continue continue
# verschachtelte Inputs rekursiv
if hasattr(v, "__dict__"): if hasattr(v, "__dict__"):
result[k] = input_to_dict(v) result[k] = input_to_dict(v)
else: else:
@@ -152,9 +151,6 @@ class Query(StrawberryProtocol):
filter_dict = input_to_dict(filter) if filter is not None else None filter_dict = input_to_dict(filter) if filter is not None else None
sort_dict = None sort_dict = None
if filter is not None:
pass
if sort is not None: if sort is not None:
sort_dict = {} sort_dict = {}
for k, v in sort.__dict__.items(): for k, v in sort.__dict__.items():
@@ -202,59 +198,55 @@ class Query(StrawberryProtocol):
sig = inspect.Signature(parameters=params, return_annotation=f.type) sig = inspect.Signature(parameters=params, return_annotation=f.type)
def _resolver(*args, **kwargs): async def _resolver(*args, **kwargs):
return f.resolver(*args, **kwargs) if f.resolver else None if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig _resolver.__signature__ = sig
return _resolver return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
# Signatur vom Original übernehmen
sig = getattr(resolver, "__signature__", None) sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver) @functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs): async def _auth_resolver(*args, **kwargs):
request = get_request() if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user() user = get_user()
# Public
if f.public:
return await self._maybe_await(resolver(*args, **kwargs))
# Auth required
if user is None: if user is None:
raise graphql_error(Unauthorized("Authentication required")) raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
# Permissions
if f.require_any_permission: if f.require_any_permission:
if not any(user.has_permission(p) for p in f.require_any_permission): if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise Forbidden("Permission denied") raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
# Custom resolvers
if f.require_any: if f.require_any:
perms, resolvers = f.require_any perms, resolvers = f.require_any
if not any(user.has_permission(p) for p in perms): if not any([await user.has_permission(p) for p in perms]):
for r in resolvers: ctx = QueryContext([x.name for x in await user.permissions])
ok = await self._maybe_await(r(user, *args, **kwargs)) resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if ok:
break
else:
raise Forbidden("Permission denied")
return await self._maybe_await(resolver(*args, **kwargs)) if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
# Signatur beibehalten
if sig: if sig:
_auth_resolver.__signature__ = sig _auth_resolver.__signature__ = sig
return _auth_resolver return _auth_resolver
@staticmethod @staticmethod
def _maybe_await(value): async def _run_resolver(r: Callable, *args, **kwargs):
if asyncio.iscoroutine(value): if iscoroutinefunction(r):
return value return await r(*args, **kwargs)
return asyncio.sleep(0, result=value) # sofort resolved Future return r(*args, **kwargs)
def _field_to_strawberry(self, f: Field) -> Any: def _field_to_strawberry(self, f: Field) -> Any:
resolver = None resolver = None

View File

@@ -31,17 +31,18 @@ class GraphQLService:
if result.errors: if result.errors:
errors = [] errors = []
for error in result.errors: for error in result.errors:
if isinstance(error, GraphQLError):
self._logger.error(f"GraphQL APIError: {error}")
errors.append({"message": error.message, "extensions": error.extensions})
continue
if isinstance(error, APIError): if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError: {error}") self._logger.error(f"GraphQL APIError", error)
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}}) errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
continue continue
self._logger.error(f"GraphQL unexpected error: {error}") if isinstance(error, GraphQLError):
self._logger.error(f"GraphQLError", error)
errors.append({"message": error.message, "extensions": error.extensions})
continue
self._logger.error(f"GraphQL unexpected error", error)
errors.append({"message": str(error), "extensions": {"code": 500}}) errors.append({"message": str(error), "extensions": {"code": 500}})
response_data["errors"] = errors response_data["errors"] = errors

View File

@@ -2,10 +2,7 @@ import logging
from typing import Type, Self from typing import Type, Self
import strawberry import strawberry
from starlette.requests import Request
from strawberry.types import ExecutionContext
from cpl.api import APIError
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
@@ -44,16 +41,6 @@ class Schema:
return types return types
def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext):
request: Request = execution_context.context.get("request")
if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError: {error}")
return {"message": error.error_message, "extensions": {"code": error.status_code}}
self._logger.error(f"GraphQL unexpected error: {error}")
return {"message": str(error), "extensions": {"code": 500}}
def build(self) -> strawberry.Schema: def build(self) -> strawberry.Schema:
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
query = self._provider.get_service(RootQuery) query = self._provider.get_service(RootQuery)