From e1ab9cf0db09e30d39dae068905ed0f6c265e15c Mon Sep 17 00:00:00 2001 From: edraft Date: Fri, 26 Sep 2025 21:56:21 +0200 Subject: [PATCH 01/20] Added gql base #181 --- example/api/src/main.py | 17 +++- example/api/src/queries/__init__.py | 0 example/api/src/queries/hello.py | 13 +++ src/cpl-api/cpl/api/abc/web_app_abc.py | 45 +++++++++++ src/cpl-api/cpl/api/application/web_app.py | 43 +++++----- src/cpl-api/cpl/api/error.py | 2 +- src/cpl-api/cpl/api/settings.py | 2 +- src/cpl-api/cpl/api/typing.py | 3 + .../cpl/application/abc/application_abc.py | 2 +- .../auth/schema/_administration/auth_user.py | 3 +- .../cpl/dependency/module/module_abc.py | 2 +- src/cpl-graphql/cpl/graphql/__init__.py | 0 .../cpl/graphql/_endpoints/__init__.py | 0 .../cpl/graphql/_endpoints/graphiql.py | 37 +++++++++ .../cpl/graphql/_endpoints/graphql.py | 13 +++ .../cpl/graphql/_endpoints/playground.py | 27 +++++++ src/cpl-graphql/cpl/graphql/abc/__init__.py | 0 src/cpl-graphql/cpl/graphql/abc/query_base.py | 54 +++++++++++++ .../cpl/graphql/application/__init__.py | 1 + .../cpl/graphql/application/graphql_app.py | 80 +++++++++++++++++++ src/cpl-graphql/cpl/graphql/graphql_module.py | 17 ++++ .../cpl/graphql/schema/__init__.py | 0 src/cpl-graphql/cpl/graphql/schema/field.py | 30 +++++++ src/cpl-graphql/cpl/graphql/schema/query.py | 6 ++ .../cpl/graphql/schema/root_query.py | 6 ++ .../cpl/graphql/service/__init__.py | 0 src/cpl-graphql/cpl/graphql/service/schema.py | 56 +++++++++++++ .../cpl/graphql/service/service.py | 31 +++++++ src/cpl-graphql/cpl/graphql/typing.py | 5 ++ src/cpl-graphql/pyproject.toml | 30 +++++++ src/cpl-graphql/requirements.dev.txt | 1 + src/cpl-graphql/requirements.txt | 2 + src/cpl-query/cpl/query/ordered_enumerable.py | 2 +- 33 files changed, 500 insertions(+), 30 deletions(-) create mode 100644 example/api/src/queries/__init__.py create mode 100644 example/api/src/queries/hello.py create mode 100644 src/cpl-api/cpl/api/abc/web_app_abc.py create mode 100644 src/cpl-graphql/cpl/graphql/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/_endpoints/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py create mode 100644 src/cpl-graphql/cpl/graphql/_endpoints/graphql.py create mode 100644 src/cpl-graphql/cpl/graphql/_endpoints/playground.py create mode 100644 src/cpl-graphql/cpl/graphql/abc/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/abc/query_base.py create mode 100644 src/cpl-graphql/cpl/graphql/application/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/application/graphql_app.py create mode 100644 src/cpl-graphql/cpl/graphql/graphql_module.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/field.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/query.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/root_query.py create mode 100644 src/cpl-graphql/cpl/graphql/service/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/service/schema.py create mode 100644 src/cpl-graphql/cpl/graphql/service/service.py create mode 100644 src/cpl-graphql/cpl/graphql/typing.py create mode 100644 src/cpl-graphql/pyproject.toml create mode 100644 src/cpl-graphql/requirements.dev.txt create mode 100644 src/cpl-graphql/requirements.txt diff --git a/example/api/src/main.py b/example/api/src/main.py index 4732035b..04266833 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -3,7 +3,7 @@ from starlette.responses import JSONResponse from cpl.api.api_module import ApiModule from cpl.api.application.web_app import WebApp from cpl.application.application_builder import ApplicationBuilder -from cpl.auth import AuthModule +from cpl.graphql.application.graphql_app import GraphQLApp from cpl.auth.permission.permissions import Permissions from cpl.auth.schema import AuthUser, Role from cpl.core.configuration import Configuration @@ -11,12 +11,15 @@ from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule +from cpl.graphql.graphql_module import GraphQLModule +from cpl.graphql.schema.root_query import RootQuery +from queries.hello import HelloQuery from scoped_service import ScopedService from service import PingService def main(): - builder = ApplicationBuilder[WebApp](WebApp) + builder = ApplicationBuilder[GraphQLApp](GraphQLApp) Configuration.add_json_file(f"appsettings.json") Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json") @@ -27,12 +30,15 @@ def main(): builder.services.add_transient(PingService) builder.services.add_module(MySQLModule) builder.services.add_module(ApiModule) + builder.services.add_module(GraphQLModule) builder.services.add_scoped(ScopedService) builder.services.add_cache(AuthUser) builder.services.add_cache(Role) + builder.services.add_transient(HelloQuery) + app = builder.build() app.with_logging() @@ -48,6 +54,13 @@ def main(): ) app.with_routes_directory("routes") + schema = app.with_graphql() + schema.query.string_field("ping", resolver=lambda *_: "pong") + schema.query.with_query("hello", HelloQuery) + + app.with_playground() + app.with_graphiql() + provider = builder.service_provider user_cache = provider.get_service(Cache[AuthUser]) role_cache = provider.get_service(Cache[Role]) diff --git a/example/api/src/queries/__init__.py b/example/api/src/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py new file mode 100644 index 00000000..eb512dbb --- /dev/null +++ b/example/api/src/queries/hello.py @@ -0,0 +1,13 @@ +import graphene + +from cpl.graphql.schema.query import Query + + +class HelloQuery(Query): + def __init__(self): + Query.__init__(self) + self.string_field( + "message", + args={"name": graphene.String(default_value="world")}, + resolver=lambda *_, name: f"Hello {name}", + ) diff --git a/src/cpl-api/cpl/api/abc/web_app_abc.py b/src/cpl-api/cpl/api/abc/web_app_abc.py new file mode 100644 index 00000000..fa7eec6e --- /dev/null +++ b/src/cpl-api/cpl/api/abc/web_app_abc.py @@ -0,0 +1,45 @@ +from abc import ABC +from enum import Enum +from typing import Self + +from starlette.applications import Starlette + +from cpl.api.model.api_route import ApiRoute +from cpl.api.model.validation_match import ValidationMatch +from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput +from cpl.application.abc.application_abc import ApplicationABC +from cpl.dependency.service_provider import ServiceProvider +from cpl.dependency.typing import Modules + + +class WebAppABC(ApplicationABC, ABC): + + def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None): + ApplicationABC.__init__(self, services, modules, required_modules) + + def with_routes_directory(self, directory: str) -> Self: ... + def with_app(self, app: Starlette) -> Self: ... + def with_routes( + self, + routes: list[ApiRoute], + method: HTTPMethods, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: ... + def with_route( + self, + path: str, + fn: TEndpoint, + method: HTTPMethods, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: ... + def with_middleware(self, middleware: PartialMiddleware) -> Self: ... + def with_authentication(self) -> Self: ... + def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: ... diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index 476e54d2..deeb2710 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -1,6 +1,6 @@ import os from enum import Enum -from typing import Mapping, Any, Callable, Self, Union +from typing import Mapping, Any, Self import uvicorn from starlette.applications import Starlette @@ -10,6 +10,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.types import ExceptionHandler +from cpl.api.abc.web_app_abc import WebAppABC from cpl.api.api_module import ApiModule from cpl.api.error import APIError from cpl.api.logger import APILogger @@ -24,8 +25,7 @@ from cpl.api.registry.policy import PolicyRegistry from cpl.api.registry.route import RouteRegistry from cpl.api.router import Router from cpl.api.settings import ApiSettings -from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver -from cpl.application.abc.application_abc import ApplicationABC +from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput from cpl.auth.auth_module import AuthModule from cpl.auth.permission.permission_module import PermissionsModule from cpl.core.configuration.configuration import Configuration @@ -33,12 +33,10 @@ from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.typing import Modules -PolicyInput = Union[dict[str, PolicyResolver], Policy] - -class WebApp(ApplicationABC): - def __init__(self, services: ServiceProvider, modules: Modules): - super().__init__(services, modules, [AuthModule, PermissionsModule, ApiModule]) +class WebApp(WebAppABC): + def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None): + WebAppABC.__init__(self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])) self._app: Starlette | None = None self._logger = services.get_service(APILogger) @@ -78,16 +76,17 @@ class WebApp(ApplicationABC): self._logger.debug(f"Allowed origins: {origins}") return origins.split(",") - def with_app(self, app: Starlette) -> Self: - assert app is not None, "app must not be None" - assert isinstance(app, Starlette), "app must be an instance of Starlette" - self._app = app - return self - def _check_for_app(self): if self._app is not None: raise ValueError("App is already set, cannot add routes or middleware") + def _validate_policies(self): + for rule in Router.get_authorization_rules(): + for policy_name in rule["policies"]: + policy = self._policies.get(policy_name) + if not policy: + self._logger.fatal(f"Authorization policy '{policy_name}' not found") + def with_routes_directory(self, directory: str) -> Self: self._check_for_app() assert directory is not None, "directory must not be None" @@ -102,6 +101,12 @@ class WebApp(ApplicationABC): return self + def with_app(self, app: Starlette) -> Self: + assert app is not None, "app must not be None" + assert isinstance(app, Starlette), "app must be an instance of Starlette" + self._app = app + return self + def with_routes( self, routes: list[ApiRoute], @@ -131,7 +136,7 @@ class WebApp(ApplicationABC): def with_route( self, path: str, - fn: Callable[[Request], Any], + fn: TEndpoint, method: HTTPMethods, authentication: bool = False, roles: list[str | Enum] = None, @@ -179,6 +184,7 @@ class WebApp(ApplicationABC): return self def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: + self._check_for_app() if policies: _policies = [] @@ -206,13 +212,6 @@ class WebApp(ApplicationABC): self.with_middleware(AuthorizationMiddleware) return self - def _validate_policies(self): - for rule in Router.get_authorization_rules(): - for policy_name in rule["policies"]: - policy = self._policies.get(policy_name) - if not policy: - self._logger.fatal(f"Authorization policy '{policy_name}' not found") - async def main(self): self._logger.debug(f"Preparing API") self._validate_policies() diff --git a/src/cpl-api/cpl/api/error.py b/src/cpl-api/cpl/api/error.py index 50329e98..8fad7e5e 100644 --- a/src/cpl-api/cpl/api/error.py +++ b/src/cpl-api/cpl/api/error.py @@ -8,7 +8,7 @@ class APIError(HTTPException): status_code = 500 def __init__(self, message: str = ""): - super().__init__(self.status_code, message) + HTTPException.__init__(self, self.status_code, message) self._message = message @property diff --git a/src/cpl-api/cpl/api/settings.py b/src/cpl-api/cpl/api/settings.py index 2f11f5d7..900c2dd2 100644 --- a/src/cpl-api/cpl/api/settings.py +++ b/src/cpl-api/cpl/api/settings.py @@ -6,7 +6,7 @@ from cpl.core.configuration import ConfigurationModelABC class ApiSettings(ConfigurationModelABC): def __init__(self, src: Optional[dict] = None): - super().__init__(src) + ConfigurationModelABC.__init__(self, src) self.option("host", str, "0.0.0.0") self.option("port", int, 5000) diff --git a/src/cpl-api/cpl/api/typing.py b/src/cpl-api/cpl/api/typing.py index c8319900..a62d4927 100644 --- a/src/cpl-api/cpl/api/typing.py +++ b/src/cpl-api/cpl/api/typing.py @@ -2,6 +2,7 @@ from typing import Union, Literal, Callable, Type, Awaitable from urllib.request import Request from starlette.middleware import Middleware +from starlette.responses import Response from starlette.types import ASGIApp from starlette.websockets import WebSocket @@ -9,6 +10,7 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.auth.schema import AuthUser TRequest = Union[Request, WebSocket] +TEndpoint = Callable[[TRequest, ...], Awaitable[Response]] | Callable[[TRequest, ...], Response] HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] PartialMiddleware = Union[ ASGIMiddleware, @@ -17,3 +19,4 @@ PartialMiddleware = Union[ Callable[[ASGIApp], ASGIApp], ] PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] +PolicyInput = Union[dict[str, PolicyResolver], "Policy"] diff --git a/src/cpl-application/cpl/application/abc/application_abc.py b/src/cpl-application/cpl/application/abc/application_abc.py index 59c43b88..a90db406 100644 --- a/src/cpl-application/cpl/application/abc/application_abc.py +++ b/src/cpl-application/cpl/application/abc/application_abc.py @@ -56,7 +56,7 @@ class ApplicationABC(ABC): module_dependency_error( type(self).__name__, - module.__name__, + module.__name__ if not isinstance(module, str) else module, ImportError( f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method." ), diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py index cae14f97..5409e468 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py @@ -10,7 +10,7 @@ from cpl.auth.permission.permissions import Permissions from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC from cpl.database.logger import DBLogger -from cpl.dependency import ServiceProvider +from cpl.dependency import get_provider class AuthUser(DbModelABC): @@ -87,3 +87,4 @@ class AuthUser(DbModelABC): self._keycloak_id = str(uuid.UUID(int=0)) await auth_user_dao.update(self) + diff --git a/src/cpl-dependency/cpl/dependency/module/module_abc.py b/src/cpl-dependency/cpl/dependency/module/module_abc.py index 971a721c..9cf0c9f8 100644 --- a/src/cpl-dependency/cpl/dependency/module/module_abc.py +++ b/src/cpl-dependency/cpl/dependency/module/module_abc.py @@ -8,7 +8,7 @@ class ModuleABC(ABC): __OPTIONAL_VARS = ["dependencies", "configuration", "singleton", "scoped", "transient", "hosted"] def __init_subclass__(cls): - super().__init_subclass__() + ABC.__init_subclass__() if f"{cls.__module__}.{cls.__name__}" == "cpl.dependency.module.module.Module": return diff --git a/src/cpl-graphql/cpl/graphql/__init__.py b/src/cpl-graphql/cpl/graphql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/__init__.py b/src/cpl-graphql/cpl/graphql/_endpoints/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py new file mode 100644 index 00000000..2aedb538 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py @@ -0,0 +1,37 @@ +from starlette.responses import HTMLResponse + +async def graphiql_endpoint(request): + return HTMLResponse(""" + + + + + GraphiQL + + + +
+ + + + + + + + + + + + """) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py new file mode 100644 index 00000000..0808d704 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py @@ -0,0 +1,13 @@ +from starlette.requests import Request +from starlette.responses import Response, JSONResponse + +from cpl.graphql.service.service import GraphQLService + + +async def graphql_endpoint(request: Request, service: GraphQLService) -> Response: + body = await request.json() + query = body.get("query") + variables = body.get("variables") + + response_data = await service.execute(query, variables, request) + return JSONResponse(response_data) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/playground.py b/src/cpl-graphql/cpl/graphql/_endpoints/playground.py new file mode 100644 index 00000000..68e59fdf --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/playground.py @@ -0,0 +1,27 @@ +from starlette.requests import Request +from starlette.responses import Response, HTMLResponse + + +async def playground_endpoint(request: Request) -> Response: + return HTMLResponse(""" + + + + + GraphQL Playground + + + + + +
+ + + + """) diff --git a/src/cpl-graphql/cpl/graphql/abc/__init__.py b/src/cpl-graphql/cpl/graphql/abc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/abc/query_base.py b/src/cpl-graphql/cpl/graphql/abc/query_base.py new file mode 100644 index 00000000..b0b47424 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/query_base.py @@ -0,0 +1,54 @@ +from typing import Callable, Any, Type + +from graphene import ObjectType + + +class QueryBase(ObjectType): + + def __init__(self): + from cpl.graphql.schema.field import Field + + ObjectType.__init__(self) + self._fields: dict[str, Field] = {} + + def get_fields(self) -> dict[str, "Field"]: + return self._fields + + def field( + self, + name: str, + t: type, + args: dict[str, Any] | None = None, + resolver: Callable | None = None, + ): + gql_type_map: dict[object, str] = { + str: "String", + int: "Int", + float: "Float", + bool: "Boolean", + } + + if t not in gql_type_map: + raise ValueError(f"Unsupported field type: {t}") + + from cpl.graphql.schema.field import Field + + self._fields[name] = Field(name, "String", resolver, args) + + def with_query(self, name: str, subquery: Type["QueryBase"]): + from cpl.graphql.schema.field import Field + + f = Field(name=name, gql_type="Object", resolver=lambda root, info, **kwargs: {}, subquery=subquery) + self._fields[name] = f + + def string_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, str, args, resolver) + + def int_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, int, args, resolver) + + def float_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, float, args, resolver) + + def bool_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, bool, args, resolver) diff --git a/src/cpl-graphql/cpl/graphql/application/__init__.py b/src/cpl-graphql/cpl/graphql/application/__init__.py new file mode 100644 index 00000000..96b2346c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/application/__init__.py @@ -0,0 +1 @@ +from .graphql_app import WebApp diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py new file mode 100644 index 00000000..ad4b06f0 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -0,0 +1,80 @@ +from enum import Enum +from typing import Self + +from cpl.api.application import WebApp +from cpl.api.model.validation_match import ValidationMatch +from cpl.dependency.service_provider import ServiceProvider +from cpl.dependency.typing import Modules +from .._endpoints.graphiql import graphiql_endpoint +from .._endpoints.graphql import graphql_endpoint +from .._endpoints.playground import playground_endpoint +from ..graphql_module import GraphQLModule +from ..service.schema import Schema + + +class GraphQLApp(WebApp): + def __init__(self, services: ServiceProvider, modules: Modules): + WebApp.__init__(self, services, modules, [GraphQLModule]) + + def with_graphql( + self, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Schema: + self.with_route( + path="/api/graphql", + fn=graphql_endpoint, + method="POST", + authentication=authentication, + roles=roles, + permissions=permissions, + policies=policies, + match=match, + ) + schema = self._services.get_service(Schema) + if schema is None: + self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.") + return schema + + def with_graphiql( + self, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self.with_route( + path="/api/graphiql", + fn=graphiql_endpoint, + method="GET", + authentication=authentication, + roles=roles, + permissions=permissions, + policies=policies, + match=match, + ) + return self + + def with_playground( + self, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self.with_route( + path="/api/playground", + fn=playground_endpoint, + method="GET", + authentication=authentication, + roles=roles, + permissions=permissions, + policies=policies, + match=match, + ) + return self diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py new file mode 100644 index 00000000..e4cd635c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -0,0 +1,17 @@ +from cpl.api import ApiModule +from cpl.dependency.module.module import Module +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.service.schema import Schema +from cpl.graphql.service.service import GraphQLService + + +class GraphQLModule(Module): + dependencies = [ApiModule] + singleton = [Schema, RootQuery] + scoped = [GraphQLService] + + @staticmethod + def configure(services: ServiceProvider) -> None: + schema = services.get_service(Schema) + schema.build() diff --git a/src/cpl-graphql/cpl/graphql/schema/__init__.py b/src/cpl-graphql/cpl/graphql/schema/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py new file mode 100644 index 00000000..c273b3f3 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -0,0 +1,30 @@ +from cpl.graphql.abc.query_base import QueryBase + + +class Field: + def __init__(self, name: str, gql_type: str, resolver: callable, args: dict | None = None, subquery: QueryBase | None = None): + self._name = name + self._gql_type = gql_type + self._resolver = resolver + self._args = args or {} + self._subquery: QueryBase | None = subquery + + @property + def name(self) -> str: + return self._name + + @property + def type(self) -> str: + return self._gql_type + + @property + def resolver(self) -> callable: + return self._resolver + + @property + def args(self) -> dict: + return self._args + + @property + def subquery(self) -> QueryBase | None: + return self._subquery \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py new file mode 100644 index 00000000..32ef46d2 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -0,0 +1,6 @@ +from cpl.graphql.abc.query_base import QueryBase + + +class Query(QueryBase): + def __init__(self): + QueryBase.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/root_query.py b/src/cpl-graphql/cpl/graphql/schema/root_query.py new file mode 100644 index 00000000..85ee1d38 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_query.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.query import Query + + +class RootQuery(Query): + def __init__(self): + Query.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/service/__init__.py b/src/cpl-graphql/cpl/graphql/service/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py new file mode 100644 index 00000000..48ce5c5f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -0,0 +1,56 @@ +import graphene + +from cpl.api import APILogger +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.query import Query +from cpl.graphql.schema.root_query import RootQuery + + +class Schema: + + def __init__(self, logger: APILogger, query: RootQuery, provider: ServiceProvider): + self._logger = logger + self._provider = provider + + self._query = query + self._schema = None + + @property + def schema(self) -> graphene.Schema | None: + return self._schema + + @property + def query(self) -> RootQuery: + return self._query + + def build(self) -> graphene.Schema: + self._schema = graphene.Schema( + query=self.to_graphene(self._query), + mutation=None, + subscription=None, + ) + return self._schema + + def to_graphene(self, query: Query, name: str | None = None): + assert query is not None, "Query cannot be None" + attrs = {} + + for field in query.get_fields().values(): + if field.type == "String": + attrs[field.name] = graphene.Field( + graphene.String, + **field.args, + resolver=field.resolver + ) + + elif field.type == "Object" and field.subquery is not None: + subquery = self._provider.get_service(field.subquery) + sub = self.to_graphene(subquery, name=field.name.capitalize()) + attrs[field.name] = graphene.Field( + sub, + **field.args, + resolver=field.resolver + ) + + class_name = name or query.__class__.__name__ + return type(class_name, (graphene.ObjectType,), attrs) diff --git a/src/cpl-graphql/cpl/graphql/service/service.py b/src/cpl-graphql/cpl/graphql/service/service.py new file mode 100644 index 00000000..d0a65891 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/service.py @@ -0,0 +1,31 @@ +from typing import Any, Dict, Optional + +from cpl.api.typing import TRequest +from cpl.graphql.service.schema import Schema + + +class GraphQLService: + def __init__(self, schema: Schema): + self._schema = schema.schema + if self._schema is None: + raise ValueError("Schema has not been built. Call schema.build() before using the service.") + + async def execute( + self, + query: str, + variables: Optional[Dict[str, Any]], + request: TRequest, + ) -> Dict[str, Any]: + result = await self._schema.execute_async( + query, + variable_values=variables, + context_value={"request": request}, + ) + + response_data: Dict[str, Any] = {} + if result.errors: + response_data["errors"] = [str(e) for e in result.errors] + if result.data: + response_data["data"] = result.data + + return response_data diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py new file mode 100644 index 00000000..58587f3f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -0,0 +1,5 @@ +from typing import Type + +from cpl.graphql.schema.query import Query + +TQuery = Type[Query] \ No newline at end of file diff --git a/src/cpl-graphql/pyproject.toml b/src/cpl-graphql/pyproject.toml new file mode 100644 index 00000000..cecb85d2 --- /dev/null +++ b/src/cpl-graphql/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["setuptools>=70.1.0", "wheel>=0.43.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "cpl-database" +version = "2024.7.0" +description = "CPL database" +readme ="CPL database package" +requires-python = ">=3.12" +license = { text = "MIT" } +authors = [ + { name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" } +] +keywords = ["cpl", "database", "backend", "shared", "library"] + +dynamic = ["dependencies", "optional-dependencies"] + +[project.urls] +Homepage = "https://www.sh-edraft.de" + +[tool.setuptools.packages.find] +where = ["."] +include = ["cpl*"] + +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } +optional-dependencies.dev = { file = ["requirements.dev.txt"] } + + diff --git a/src/cpl-graphql/requirements.dev.txt b/src/cpl-graphql/requirements.dev.txt new file mode 100644 index 00000000..e7664b42 --- /dev/null +++ b/src/cpl-graphql/requirements.dev.txt @@ -0,0 +1 @@ +black==25.1.0 \ No newline at end of file diff --git a/src/cpl-graphql/requirements.txt b/src/cpl-graphql/requirements.txt new file mode 100644 index 00000000..abe92c36 --- /dev/null +++ b/src/cpl-graphql/requirements.txt @@ -0,0 +1,2 @@ +cpl-api +graphene==3.4.3 \ No newline at end of file diff --git a/src/cpl-query/cpl/query/ordered_enumerable.py b/src/cpl-query/cpl/query/ordered_enumerable.py index 89edc3d7..03405057 100644 --- a/src/cpl-query/cpl/query/ordered_enumerable.py +++ b/src/cpl-query/cpl/query/ordered_enumerable.py @@ -6,7 +6,7 @@ from cpl.query.typing import K class OrderedEnumerable(Enumerable[T]): def __init__(self, source, key_selectors: List[tuple[Callable[[T], K], bool]]): - super().__init__(source) + Enumerable.__init__(self, source) self._key_selectors = key_selectors def __iter__(self) -> Iterator[T]: From b0f1fb983985c562180e4ac877c64859e3018c5f Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 02:31:43 +0200 Subject: [PATCH 02/20] Removed query base #181 --- src/cpl-graphql/cpl/graphql/abc/__init__.py | 0 src/cpl-graphql/cpl/graphql/abc/query_base.py | 54 ------------------- src/cpl-graphql/cpl/graphql/graphql_module.py | 6 +-- src/cpl-graphql/cpl/graphql/schema/field.py | 8 +-- src/cpl-graphql/cpl/graphql/schema/query.py | 54 +++++++++++++++++-- src/cpl-graphql/cpl/graphql/service/schema.py | 2 +- .../cpl/graphql/service/service.py | 4 +- 7 files changed, 61 insertions(+), 67 deletions(-) delete mode 100644 src/cpl-graphql/cpl/graphql/abc/__init__.py delete mode 100644 src/cpl-graphql/cpl/graphql/abc/query_base.py diff --git a/src/cpl-graphql/cpl/graphql/abc/__init__.py b/src/cpl-graphql/cpl/graphql/abc/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cpl-graphql/cpl/graphql/abc/query_base.py b/src/cpl-graphql/cpl/graphql/abc/query_base.py deleted file mode 100644 index b0b47424..00000000 --- a/src/cpl-graphql/cpl/graphql/abc/query_base.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Callable, Any, Type - -from graphene import ObjectType - - -class QueryBase(ObjectType): - - def __init__(self): - from cpl.graphql.schema.field import Field - - ObjectType.__init__(self) - self._fields: dict[str, Field] = {} - - def get_fields(self) -> dict[str, "Field"]: - return self._fields - - def field( - self, - name: str, - t: type, - args: dict[str, Any] | None = None, - resolver: Callable | None = None, - ): - gql_type_map: dict[object, str] = { - str: "String", - int: "Int", - float: "Float", - bool: "Boolean", - } - - if t not in gql_type_map: - raise ValueError(f"Unsupported field type: {t}") - - from cpl.graphql.schema.field import Field - - self._fields[name] = Field(name, "String", resolver, args) - - def with_query(self, name: str, subquery: Type["QueryBase"]): - from cpl.graphql.schema.field import Field - - f = Field(name=name, gql_type="Object", resolver=lambda root, info, **kwargs: {}, subquery=subquery) - self._fields[name] = f - - def string_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, str, args, resolver) - - def int_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, int, args, resolver) - - def float_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, float, args, resolver) - - def bool_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, bool, args, resolver) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index e4cd635c..cb20a7d3 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,4 +1,4 @@ -from cpl.api import ApiModule +from cpl.api.api_module import ApiModule from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.schema.root_query import RootQuery @@ -8,8 +8,8 @@ from cpl.graphql.service.service import GraphQLService class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [Schema, RootQuery] - scoped = [GraphQLService] + singleton = [Schema] + scoped = [GraphQLService, RootQuery] @staticmethod def configure(services: ServiceProvider) -> None: diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index c273b3f3..13261675 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -1,13 +1,13 @@ -from cpl.graphql.abc.query_base import QueryBase +from cpl.graphql.schema.query import Query class Field: - def __init__(self, name: str, gql_type: str, resolver: callable, args: dict | None = None, subquery: QueryBase | None = None): + def __init__(self, name: str, gql_type: str, resolver: callable, args: dict | None = None, subquery: Query | None = None): self._name = name self._gql_type = gql_type self._resolver = resolver self._args = args or {} - self._subquery: QueryBase | None = subquery + self._subquery: Query | None = subquery @property def name(self) -> str: @@ -26,5 +26,5 @@ class Field: return self._args @property - def subquery(self) -> QueryBase | None: + def subquery(self) -> Query | None: return self._subquery \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 32ef46d2..13f4c62f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,6 +1,54 @@ -from cpl.graphql.abc.query_base import QueryBase +from typing import Callable, Any, Type + +from graphene import ObjectType -class Query(QueryBase): +class Query(ObjectType): + def __init__(self): - QueryBase.__init__(self) \ No newline at end of file + from cpl.graphql.schema.field import Field + + ObjectType.__init__(self) + self._fields: dict[str, Field] = {} + + def get_fields(self) -> dict[str, "Field"]: + return self._fields + + def field( + self, + name: str, + t: type, + args: dict[str, Any] | None = None, + resolver: Callable | None = None, + ): + gql_type_map: dict[object, str] = { + str: "String", + int: "Int", + float: "Float", + bool: "Boolean", + } + + if t not in gql_type_map: + raise ValueError(f"Unsupported field type: {t}") + + from cpl.graphql.schema.field import Field + + self._fields[name] = Field(name, "String", resolver, args) + + def with_query(self, name: str, subquery: Type["Query"]): + from cpl.graphql.schema.field import Field + + f = Field(name=name, gql_type="Object", resolver=lambda root, info, **kwargs: {}, subquery=subquery) + self._fields[name] = f + + def string_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, str, args, resolver) + + def int_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, int, args, resolver) + + def float_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, float, args, resolver) + + def bool_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): + self.field(name, bool, args, resolver) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 48ce5c5f..0dcf02b6 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,6 +1,6 @@ import graphene -from cpl.api import APILogger +from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.schema.query import Query from cpl.graphql.schema.root_query import RootQuery diff --git a/src/cpl-graphql/cpl/graphql/service/service.py b/src/cpl-graphql/cpl/graphql/service/service.py index d0a65891..54c4f388 100644 --- a/src/cpl-graphql/cpl/graphql/service/service.py +++ b/src/cpl-graphql/cpl/graphql/service/service.py @@ -6,9 +6,9 @@ from cpl.graphql.service.schema import Schema class GraphQLService: def __init__(self, schema: Schema): - self._schema = schema.schema - if self._schema is None: + if schema.schema is None: raise ValueError("Schema has not been built. Call schema.build() before using the service.") + self._schema = schema.schema async def execute( self, From 683805137ae9d3fdfa8f45e40a4d5c9aae2e22da Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 03:15:55 +0200 Subject: [PATCH 03/20] Added arguments to field #181 --- example/api/src/queries/hello.py | 8 ++- .../cpl/graphql/schema/argument.py | 22 ++++++++ src/cpl-graphql/cpl/graphql/schema/field.py | 33 +++++++++--- src/cpl-graphql/cpl/graphql/schema/query.py | 51 +++++++++---------- src/cpl-graphql/cpl/graphql/service/schema.py | 34 ++++++++----- src/cpl-graphql/cpl/graphql/typing.py | 8 +-- src/cpl-graphql/cpl/graphql/utils/__init__.py | 0 .../cpl/graphql/utils/type_converter.py | 38 ++++++++++++++ 8 files changed, 137 insertions(+), 57 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/schema/argument.py create mode 100644 src/cpl-graphql/cpl/graphql/utils/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/utils/type_converter.py diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index eb512dbb..58c747b1 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -1,5 +1,4 @@ -import graphene - +from cpl.api.middleware.request import get_request from cpl.graphql.schema.query import Query @@ -8,6 +7,5 @@ class HelloQuery(Query): Query.__init__(self) self.string_field( "message", - args={"name": graphene.String(default_value="world")}, - resolver=lambda *_, name: f"Hello {name}", - ) + resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}", + ).with_argument(str, "name", "Name to greet", "world") diff --git a/src/cpl-graphql/cpl/graphql/schema/argument.py b/src/cpl-graphql/cpl/graphql/schema/argument.py new file mode 100644 index 00000000..2f3b938c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/argument.py @@ -0,0 +1,22 @@ +class Argument: + def __init__(self, t: type, name: str, description: str = None, default_value=None): + self._type = t + self._name = name + self._description = description + self._default_value = default_value + + @property + def type(self) -> type: + return self._type + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str | None: + return self._description + + @property + def default_value(self): + return self._default_value diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 13261675..2744f6d2 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -1,20 +1,25 @@ -from cpl.graphql.schema.query import Query +from typing import Self + +from cpl.graphql.schema.argument import Argument +from cpl.graphql.typing import TQuery class Field: - def __init__(self, name: str, gql_type: str, resolver: callable, args: dict | None = None, subquery: Query | None = None): + + def __init__(self, name: str, gql_type: type, resolver: callable, subquery: TQuery | None = None): self._name = name self._gql_type = gql_type self._resolver = resolver - self._args = args or {} - self._subquery: Query | None = subquery + self._subquery = subquery + + self._args: dict[str, Argument] = {} @property def name(self) -> str: return self._name @property - def type(self) -> str: + def type(self) -> type: return self._gql_type @property @@ -26,5 +31,19 @@ class Field: return self._args @property - def subquery(self) -> Query | None: - return self._subquery \ No newline at end of file + def subquery(self) -> TQuery | None: + return self._subquery + + def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self: + if name in self._args: + raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") + self._args[name] = Argument(name, arg_type, description, default_value) + return self + + def with_arguments(self, args: list[Argument]) -> Self: + for arg in args: + if not isinstance(arg, Argument): + raise ValueError(f"Expected Argument instance, got {type(arg)}") + + self.with_argument(arg.name, arg.type, arg.description, arg.default_value) + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 13f4c62f..ba93c3a1 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,7 +1,11 @@ -from typing import Callable, Any, Type +from typing import Callable, Type from graphene import ObjectType +from cpl.graphql.schema.field import Field +from cpl.graphql.typing import Resolver +from cpl.graphql.utils.type_converter import TypeConverter + class Query(ObjectType): @@ -11,44 +15,35 @@ class Query(ObjectType): ObjectType.__init__(self) self._fields: dict[str, Field] = {} - def get_fields(self) -> dict[str, "Field"]: + def get_fields(self) -> dict[str, Field]: return self._fields def field( - self, - name: str, - t: type, - args: dict[str, Any] | None = None, - resolver: Callable | None = None, - ): - gql_type_map: dict[object, str] = { - str: "String", - int: "Int", - float: "Float", - bool: "Boolean", - } - - if t not in gql_type_map: - raise ValueError(f"Unsupported field type: {t}") - + self, + name: str, + t: type, + resolver: Callable | None = None, + ) -> "Field": from cpl.graphql.schema.field import Field - self._fields[name] = Field(name, "String", resolver, args) + self._fields[name] = Field(name, t, resolver) + return self._fields[name] def with_query(self, name: str, subquery: Type["Query"]): from cpl.graphql.schema.field import Field - f = Field(name=name, gql_type="Object", resolver=lambda root, info, **kwargs: {}, subquery=subquery) + f = Field(name=name, gql_type=object, resolver=lambda root, info, **kwargs: {}, subquery=subquery) self._fields[name] = f + return self._fields[name] - def string_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, str, args, resolver) + def string_field(self, name: str, resolver: Resolver = None) -> "Field": + return self.field(name, str, resolver) - def int_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, int, args, resolver) + def int_field(self, name: str, resolver: Resolver = None) -> "Field": + return self.field(name, int, resolver) - def float_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, float, args, resolver) + def float_field(self, name: str, resolver: Resolver = None) -> "Field": + return self.field(name, float, resolver) - def bool_field(self, name: str, args: dict[str, Any] | None = None, resolver: Callable | None = None): - self.field(name, bool, args, resolver) + def bool_field(self, name: str, resolver: Resolver = None) -> "Field": + return self.field(name, bool, resolver) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 0dcf02b6..69e76e48 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,9 +1,14 @@ +from typing import Type + import graphene from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.argument import Argument from cpl.graphql.schema.query import Query from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.typing import Resolver +from cpl.graphql.utils.type_converter import TypeConverter class Schema: @@ -31,26 +36,29 @@ class Schema: ) return self._schema + @staticmethod + def _field_to_graphene(t: Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field: + arguments = {} + if args is not None: + arguments = { + arg.name: graphene.Argument(TypeConverter.to_graphene(arg.type), description=arg.description, default_value=arg.default_value) + for arg in args.values() + } + + return graphene.Field(t, args=arguments, resolver=resolver) + def to_graphene(self, query: Query, name: str | None = None): assert query is not None, "Query cannot be None" attrs = {} for field in query.get_fields().values(): - if field.type == "String": - attrs[field.name] = graphene.Field( - graphene.String, - **field.args, - resolver=field.resolver - ) - - elif field.type == "Object" and field.subquery is not None: + if field.type == object and field.subquery is not None: subquery = self._provider.get_service(field.subquery) sub = self.to_graphene(subquery, name=field.name.capitalize()) - attrs[field.name] = graphene.Field( - sub, - **field.args, - resolver=field.resolver - ) + attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver) + continue + + attrs[field.name] = self._field_to_graphene(TypeConverter.to_graphene(field.type), field.args, field.resolver) class_name = name or query.__class__.__name__ return type(class_name, (graphene.ObjectType,), attrs) diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py index 58587f3f..d5b63494 100644 --- a/src/cpl-graphql/cpl/graphql/typing.py +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -1,5 +1,5 @@ -from typing import Type +from typing import Type, Callable -from cpl.graphql.schema.query import Query - -TQuery = Type[Query] \ No newline at end of file +TQuery = Type["Query"] +Resolver = Callable +ScalarType = str | int | float | bool | object \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/utils/__init__.py b/src/cpl-graphql/cpl/graphql/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/utils/type_converter.py b/src/cpl-graphql/cpl/graphql/utils/type_converter.py new file mode 100644 index 00000000..c89b8928 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/utils/type_converter.py @@ -0,0 +1,38 @@ +from typing import Type + +import graphene + +from cpl.graphql.typing import ScalarType + + +class TypeConverter: + + @staticmethod + def from_graphene(t: Type[graphene.Scalar]) -> ScalarType: + graphene_type_map: dict[Type[graphene.Scalar], ScalarType] = { + graphene.String: str, + graphene.Int: int, + graphene.Float: float, + graphene.Boolean: bool, + graphene.ObjectType: object, + } + + if t not in graphene_type_map: + raise ValueError(f"Unsupported field type: {t}") + + return graphene_type_map[t] + + @staticmethod + def to_graphene(t: ScalarType) -> Type[graphene.Scalar]: + type_graphene_map: dict[ScalarType, Type[graphene.Scalar]] = { + str: graphene.String, + int: graphene.Int, + float: graphene.Float, + bool: graphene.Boolean, + object: graphene.ObjectType, + } + + if t not in type_graphene_map: + raise ValueError(f"Unsupported field type: {t}") + + return type_graphene_map[t] From a35b44b3b5c45913dca05806586be4a937fc47cd Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 04:08:32 +0200 Subject: [PATCH 04/20] [WIP] collection #181 --- example/api/src/main.py | 8 +- example/api/src/queries/cities.py | 39 ++++++++ example/api/src/queries/hello.py | 19 ++++ example/api/src/queries/user.py | 39 ++++++++ .../cpl/dependency/service_provider.py | 2 +- src/cpl-graphql/cpl/graphql/graphql_module.py | 6 +- .../cpl/graphql/schema/collection.py | 18 ++++ src/cpl-graphql/cpl/graphql/schema/field.py | 8 +- .../cpl/graphql/schema/filter/__init__.py | 0 .../cpl/graphql/schema/filter/filter.py | 9 ++ .../cpl/graphql/schema/graph_type.py | 10 +++ src/cpl-graphql/cpl/graphql/schema/input.py | 26 ++++++ .../cpl/graphql/schema/object_graph_type.py | 9 ++ src/cpl-graphql/cpl/graphql/schema/query.py | 48 +++++++++- .../cpl/graphql/schema/sort/__init__.py | 0 .../cpl/graphql/schema/sort/sort.py | 9 ++ .../cpl/graphql/schema/sort/sort_order.py | 6 ++ src/cpl-graphql/cpl/graphql/service/schema.py | 49 +++------- .../cpl/graphql/service/type_converter.py | 89 +++++++++++++++++++ .../cpl/graphql/utils/name_pipe.py | 28 ++++++ .../cpl/graphql/utils/type_converter.py | 38 -------- 21 files changed, 375 insertions(+), 85 deletions(-) create mode 100644 example/api/src/queries/cities.py create mode 100644 example/api/src/queries/user.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/collection.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/input.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/object_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/sort/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/sort/sort.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py create mode 100644 src/cpl-graphql/cpl/graphql/service/type_converter.py create mode 100644 src/cpl-graphql/cpl/graphql/utils/name_pipe.py delete mode 100644 src/cpl-graphql/cpl/graphql/utils/type_converter.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 04266833..b1f525cc 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,9 +1,9 @@ from starlette.responses import JSONResponse +from api.src.queries.cities import CityGraphType +from api.src.queries.hello import UserGraphType from cpl.api.api_module import ApiModule -from cpl.api.application.web_app import WebApp from cpl.application.application_builder import ApplicationBuilder -from cpl.graphql.application.graphql_app import GraphQLApp from cpl.auth.permission.permissions import Permissions from cpl.auth.schema import AuthUser, Role from cpl.core.configuration import Configuration @@ -11,8 +11,8 @@ from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule +from cpl.graphql.application.graphql_app import GraphQLApp from cpl.graphql.graphql_module import GraphQLModule -from cpl.graphql.schema.root_query import RootQuery from queries.hello import HelloQuery from scoped_service import ScopedService from service import PingService @@ -37,6 +37,8 @@ def main(): builder.services.add_cache(AuthUser) builder.services.add_cache(Role) + builder.services.add_transient(CityGraphType) + builder.services.add_transient(UserGraphType) builder.services.add_transient(HelloQuery) app = builder.build() diff --git a/example/api/src/queries/cities.py b/example/api/src/queries/cities.py new file mode 100644 index 00000000..4234f8e2 --- /dev/null +++ b/example/api/src/queries/cities.py @@ -0,0 +1,39 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.object_graph_type import ObjectGraphType + +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class City: + def __init__(self, id: int, name: str): + self.id = id + self.name = name + + +class CityFilter(Filter[City]): + def __init__(self): + Filter.__init__(self) + self.field("id", int) + self.field("name", str) + + +class CitySort(Sort[City]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("name", SortOrder) + + +class CityGraphType(ObjectGraphType): + def __init__(self): + ObjectGraphType.__init__(self) + + self.string_field( + "id", + resolver=lambda user, *_: user.id, + ) + self.string_field( + "name", + resolver=lambda user, *_: user.name, + ) diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index 58c747b1..0f61c27c 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -1,6 +1,10 @@ +from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City +from api.src.queries.user import User, UserFilter, UserSort, UserGraphType from cpl.api.middleware.request import get_request from cpl.graphql.schema.query import Query +users = [User(i, f"User {i}") for i in range(1, 101)] +cities = [City(i, f"City {i}") for i in range(1, 101)] class HelloQuery(Query): def __init__(self): @@ -9,3 +13,18 @@ class HelloQuery(Query): "message", resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}", ).with_argument(str, "name", "Name to greet", "world") + + self.collection_field( + UserGraphType, + "users", + UserFilter, + UserSort, + resolver=lambda *_: users, + ) + self.collection_field( + CityGraphType, + "cities", + CityFilter, + CitySort, + resolver=lambda *_: cities, + ) diff --git a/example/api/src/queries/user.py b/example/api/src/queries/user.py new file mode 100644 index 00000000..3c4dd70c --- /dev/null +++ b/example/api/src/queries/user.py @@ -0,0 +1,39 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.object_graph_type import ObjectGraphType + +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class User: + def __init__(self, id: int, name: str): + self.id = id + self.name = name + + +class UserFilter(Filter[User]): + def __init__(self): + Filter.__init__(self) + self.field("id", int) + self.field("name", str) + + +class UserSort(Sort[User]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("name", SortOrder) + + +class UserGraphType(ObjectGraphType): + def __init__(self): + ObjectGraphType.__init__(self) + + self.string_field( + "id", + resolver=lambda user, *_: user.id, + ) + self.string_field( + "name", + resolver=lambda user, *_: user.name, + ) diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index 23a4216d..38e0ae46 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -25,7 +25,7 @@ class ServiceProvider: for descriptor in self._service_descriptors: if typing.get_origin(service_type) is None and ( - descriptor.service_type == service_type + descriptor.service_type.__name__ == service_type.__name__ or typing.get_origin(descriptor.base_type) is None and issubclass(descriptor.base_type, service_type) ): diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index cb20a7d3..29d9d79d 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,15 +1,17 @@ from cpl.api.api_module import ApiModule from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.collection import CollectionGraphType from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema from cpl.graphql.service.service import GraphQLService +from cpl.graphql.service.type_converter import TypeConverter class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [Schema] - scoped = [GraphQLService, RootQuery] + singleton = [TypeConverter, Schema] + scoped = [GraphQLService, RootQuery, CollectionGraphType] @staticmethod def configure(services: ServiceProvider) -> None: diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py new file mode 100644 index 00000000..f14269fc --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -0,0 +1,18 @@ +from typing import Generic, Type + +from cpl.core.typing import T +from cpl.graphql.schema.graph_type import GraphType + + +class Collection(Generic[T]): + def __init__(self, nodes: list[T], total_count: int, count: int): + self.nodes = nodes + self.totalCount = total_count + self.count = count + +class CollectionGraphType(GraphType[T]): + def __init__(self, t: Type[GraphType[T]]): + GraphType.__init__(self) + self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount) + self.string_field("count", resolver=lambda obj, *_: obj.count) + self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes) diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 2744f6d2..e6358e83 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -1,12 +1,12 @@ from typing import Self from cpl.graphql.schema.argument import Argument -from cpl.graphql.typing import TQuery +from cpl.graphql.typing import TQuery, Resolver class Field: - def __init__(self, name: str, gql_type: type, resolver: callable, subquery: TQuery | None = None): + def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None): self._name = name self._gql_type = gql_type self._resolver = resolver @@ -37,7 +37,7 @@ class Field: def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") - self._args[name] = Argument(name, arg_type, description, default_value) + self._args[name] = Argument(arg_type, name, description, default_value) return self def with_arguments(self, args: list[Argument]) -> Self: @@ -45,5 +45,5 @@ class Field: if not isinstance(arg, Argument): raise ValueError(f"Expected Argument instance, got {type(arg)}") - self.with_argument(arg.name, arg.type, arg.description, arg.default_value) + self.with_argument(arg.type, arg.name, arg.description, arg.default_value) return self diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/__init__.py b/src/cpl-graphql/cpl/graphql/schema/filter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py new file mode 100644 index 00000000..26339bbc --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -0,0 +1,9 @@ +from cpl.core.typing import T +from cpl.graphql.schema.input import Input + + +class Filter(Input[T]): + def __init__( + self, + ): + Input.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/graph_type.py b/src/cpl-graphql/cpl/graphql/schema/graph_type.py new file mode 100644 index 00000000..8fff69cf --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/graph_type.py @@ -0,0 +1,10 @@ +from typing import Generic + +from cpl.core.typing import T +from cpl.graphql.schema.query import Query + + +class GraphType(Generic[T], Query): + + def __init__(self): + Query.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py new file mode 100644 index 00000000..8f66c69c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -0,0 +1,26 @@ +from datetime import datetime +from enum import Enum +from typing import Type, Generic + +import graphene + +from cpl.core.typing import T +from cpl.graphql.schema.field import Field + + +class Input(Generic[T], graphene.InputObjectType): + def __init__( + self, + ): + graphene.InputObjectType.__init__(self) + self._fields: dict[str, Field] = {} + + def get_fields(self) -> dict[str, Field]: + return self._fields + + def field( + self, + field: str, + t: Type["Input"] | Type[int | str | bool | datetime | list | Enum], + ): + self._fields[field] = Field(field, t) diff --git a/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py new file mode 100644 index 00000000..5cc46a0a --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py @@ -0,0 +1,9 @@ +from cpl.core.typing import T +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.query import Query + + +class ObjectGraphType(GraphType[T], Query): + + def __init__(self): + Query.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index ba93c3a1..0e14d6b9 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -2,9 +2,12 @@ from typing import Callable, Type from graphene import ObjectType +from cpl.graphql.schema.argument import Argument from cpl.graphql.schema.field import Field +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.typing import Resolver -from cpl.graphql.utils.type_converter import TypeConverter class Query(ObjectType): @@ -32,7 +35,7 @@ class Query(ObjectType): def with_query(self, name: str, subquery: Type["Query"]): from cpl.graphql.schema.field import Field - f = Field(name=name, gql_type=object, resolver=lambda root, info, **kwargs: {}, subquery=subquery) + f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery) self._fields[name] = f return self._fields[name] @@ -47,3 +50,44 @@ class Query(ObjectType): def bool_field(self, name: str, resolver: Resolver = None) -> "Field": return self.field(name, bool, resolver) + + def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field": + return self.field(name, list[t], resolver) + + def collection_field( + self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None + ) -> "Field": + from cpl.graphql.schema.collection import Collection, CollectionGraphType + + def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int): + items = resolver() + + for field in filter or []: + if filter[field] is None: + continue + + items = [item for item in items if getattr(item, field) == filter[field]] + + for field in sort or []: + if sort[field] is None: + continue + + reverse = sort[field] == SortOrder.DESC + items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse) + + total_count = len(items) + paged = items[skip : skip + take] + return Collection(nodes=paged, total_count=total_count, count=len(paged)) + + # base = getattr(t, "__gqlname__", t.__class__.__name__) + wrapper = CollectionGraphType(t) + # wrapper.set_graphql_name(f"{base}Collection") + f = self.field(name, wrapper, resolver=_resolve_collection) + return f.with_arguments( + [ + Argument(filter_type, "filter"), + Argument(sort_type, "sort"), + Argument(int, "skip", default_value=0), + Argument(int, "take", default_value=10), + ] + ) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/__init__.py b/src/cpl-graphql/cpl/graphql/schema/sort/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/sort.py b/src/cpl-graphql/cpl/graphql/schema/sort/sort.py new file mode 100644 index 00000000..ccbb6980 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/sort/sort.py @@ -0,0 +1,9 @@ +from cpl.core.typing import T +from cpl.graphql.schema.input import Input + + +class Sort(Input[T]): + def __init__( + self, + ): + Input.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py new file mode 100644 index 00000000..cc3122a4 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py @@ -0,0 +1,6 @@ +from enum import Enum, auto + + +class SortOrder(Enum): + ASC = auto() + DESC = auto() \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 69e76e48..9912c739 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,21 +1,22 @@ -from typing import Type - import graphene from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider -from cpl.graphql.schema.argument import Argument -from cpl.graphql.schema.query import Query +from cpl.graphql.schema.collection import CollectionGraphType +from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.root_query import RootQuery -from cpl.graphql.typing import Resolver -from cpl.graphql.utils.type_converter import TypeConverter +from cpl.graphql.service.type_converter import TypeConverter class Schema: - def __init__(self, logger: APILogger, query: RootQuery, provider: ServiceProvider): + def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider): self._logger = logger self._provider = provider + self._converter = converter + + self._types = set(GraphType.__subclasses__()) + self._types.remove(CollectionGraphType) self._query = query self._schema = None @@ -28,37 +29,15 @@ class Schema: def query(self) -> RootQuery: return self._query + def with_type(self, t: type[GraphType]): + self._types.add(t) + return self + def build(self) -> graphene.Schema: self._schema = graphene.Schema( - query=self.to_graphene(self._query), + query=self._converter.to_graphene(self._query), mutation=None, subscription=None, + # types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None, ) return self._schema - - @staticmethod - def _field_to_graphene(t: Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field: - arguments = {} - if args is not None: - arguments = { - arg.name: graphene.Argument(TypeConverter.to_graphene(arg.type), description=arg.description, default_value=arg.default_value) - for arg in args.values() - } - - return graphene.Field(t, args=arguments, resolver=resolver) - - def to_graphene(self, query: Query, name: str | None = None): - assert query is not None, "Query cannot be None" - attrs = {} - - for field in query.get_fields().values(): - if field.type == object and field.subquery is not None: - subquery = self._provider.get_service(field.subquery) - sub = self.to_graphene(subquery, name=field.name.capitalize()) - attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver) - continue - - attrs[field.name] = self._field_to_graphene(TypeConverter.to_graphene(field.type), field.args, field.resolver) - - class_name = name or query.__class__.__name__ - return type(class_name, (graphene.ObjectType,), attrs) diff --git a/src/cpl-graphql/cpl/graphql/service/type_converter.py b/src/cpl-graphql/cpl/graphql/service/type_converter.py new file mode 100644 index 00000000..bf483b42 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/type_converter.py @@ -0,0 +1,89 @@ +import typing +from enum import Enum +from inspect import isclass + +import graphene +from typing import Any, get_origin, get_args + +from cpl.dependency import ServiceProvider +from cpl.graphql.schema.argument import Argument +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.object_graph_type import ObjectGraphType +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.typing import Resolver +from cpl.graphql.utils.name_pipe import NamePipe + + +class TypeConverter: + __scalar_map: dict[Any, type[graphene.Scalar]] = { + str: graphene.String, + int: graphene.Int, + float: graphene.Float, + bool: graphene.Boolean, + } + + def __init__(self, provider: ServiceProvider): + self._provider = provider + + def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field: + arguments = {} + if args is not None: + arguments = { + arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value) + for arg in args.values() + } + + return graphene.Field(t, args=arguments, resolver=resolver) + + def to_graphene(self, t: Any, name: str | None = None) -> Any: + try: + origin = get_origin(t) + args = get_args(t) + + if t in self.__scalar_map: + return self.__scalar_map[t] + + if origin in (list, typing.List): + if not args: + raise ValueError("List must specify element type, e.g. list[str]") + inner = self.to_graphene(args[0]) + return graphene.List(inner) + + if t is list or t is typing.List: + raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]") + + if isclass(t) and issubclass(t, Enum): + return graphene.Enum.from_enum(t) + + from cpl.graphql.schema.query import Query + if isinstance(t, type) and issubclass(t, (Query)): + query = self._provider.get_service(t) + if query is None: + raise ValueError(f"Could not resolve query of type {t}") + + t = query + + if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)): + t = t() + + if isinstance(t, (Query, Filter, Sort)): + attrs = {} + for field in t.get_fields().values(): + if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None: + subquery = self._provider.get_service(field.subquery) + sub = self.to_graphene(subquery, name=field.name.capitalize()) + attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver) + continue + + attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver) + + class_name = NamePipe.to_str(name or t.__class__) + if isinstance(t, (Filter, Sort)): + return type(class_name, (graphene.InputObjectType,), attrs) + + return type(class_name, (graphene.ObjectType,), attrs) + + raise ValueError(f"Unsupported field type: {t}") + except Exception as e: + raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/utils/name_pipe.py b/src/cpl-graphql/cpl/graphql/utils/name_pipe.py new file mode 100644 index 00000000..7e9b72b1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/utils/name_pipe.py @@ -0,0 +1,28 @@ +from cpl.core.pipes import PipeABC +from cpl.core.typing import T +from cpl.graphql.schema.collection import CollectionGraphType +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.object_graph_type import ObjectGraphType + + +class NamePipe(PipeABC): + + @staticmethod + def to_str(value: type, *args) -> str: + if isinstance(value, str): + return value + + if not isinstance(value, type): + raise ValueError(f"Expected a type, got {type(value)}") + + if issubclass(value, CollectionGraphType): + return f"{value.__name__.replace(GraphType.__name__, "")}" + + if issubclass(value, (ObjectGraphType, GraphType)): + return value.__name__.replace(GraphType.__name__, "") + + return value.__name__ + + @staticmethod + def from_str(value: str, *args) -> T: + pass diff --git a/src/cpl-graphql/cpl/graphql/utils/type_converter.py b/src/cpl-graphql/cpl/graphql/utils/type_converter.py deleted file mode 100644 index c89b8928..00000000 --- a/src/cpl-graphql/cpl/graphql/utils/type_converter.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Type - -import graphene - -from cpl.graphql.typing import ScalarType - - -class TypeConverter: - - @staticmethod - def from_graphene(t: Type[graphene.Scalar]) -> ScalarType: - graphene_type_map: dict[Type[graphene.Scalar], ScalarType] = { - graphene.String: str, - graphene.Int: int, - graphene.Float: float, - graphene.Boolean: bool, - graphene.ObjectType: object, - } - - if t not in graphene_type_map: - raise ValueError(f"Unsupported field type: {t}") - - return graphene_type_map[t] - - @staticmethod - def to_graphene(t: ScalarType) -> Type[graphene.Scalar]: - type_graphene_map: dict[ScalarType, Type[graphene.Scalar]] = { - str: graphene.String, - int: graphene.Int, - float: graphene.Float, - bool: graphene.Boolean, - object: graphene.ObjectType, - } - - if t not in type_graphene_map: - raise ValueError(f"Unsupported field type: {t}") - - return type_graphene_map[t] From ada50c693e45aeafa4db8757dea3007bc06e89ee Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 18:25:50 +0200 Subject: [PATCH 05/20] Changed to strawberry #181 --- example/api/src/main.py | 11 +- example/api/src/queries/cities.py | 12 +- example/api/src/queries/hello.py | 6 +- example/api/src/queries/user.py | 16 +- src/cpl-graphql/cpl/graphql/abc/__init__.py | 0 .../cpl/graphql/abc/strawberry_protocol.py | 9 + src/cpl-graphql/cpl/graphql/graphql_module.py | 6 +- .../cpl/graphql/schema/argument.py | 20 +- .../cpl/graphql/schema/collection.py | 59 ++++-- src/cpl-graphql/cpl/graphql/schema/field.py | 37 +++- .../cpl/graphql/schema/filter/filter.py | 4 +- .../cpl/graphql/schema/graph_type.py | 2 +- src/cpl-graphql/cpl/graphql/schema/input.py | 40 ++-- .../cpl/graphql/schema/object_graph_type.py | 9 - src/cpl-graphql/cpl/graphql/schema/query.py | 178 +++++++++++++----- src/cpl-graphql/cpl/graphql/service/schema.py | 45 +++-- .../cpl/graphql/service/service.py | 2 +- .../cpl/graphql/service/type_converter.py | 89 --------- src/cpl-graphql/requirements.txt | 2 +- 19 files changed, 317 insertions(+), 230 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/abc/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py delete mode 100644 src/cpl-graphql/cpl/graphql/schema/object_graph_type.py delete mode 100644 src/cpl-graphql/cpl/graphql/service/type_converter.py diff --git a/example/api/src/main.py b/example/api/src/main.py index b1f525cc..bfb953fb 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,7 +1,8 @@ from starlette.responses import JSONResponse -from api.src.queries.cities import CityGraphType +from api.src.queries.cities import CityGraphType, CityFilter, CitySort from api.src.queries.hello import UserGraphType +from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder from cpl.auth.permission.permissions import Permissions @@ -38,7 +39,13 @@ def main(): builder.services.add_cache(Role) builder.services.add_transient(CityGraphType) + builder.services.add_transient(CityFilter) + builder.services.add_transient(CitySort) + builder.services.add_transient(UserGraphType) + builder.services.add_transient(UserFilter) + builder.services.add_transient(UserSort) + builder.services.add_transient(HelloQuery) app = builder.build() @@ -57,7 +64,7 @@ def main(): app.with_routes_directory("routes") schema = app.with_graphql() - schema.query.string_field("ping", resolver=lambda *_: "pong") + schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) app.with_playground() diff --git a/example/api/src/queries/cities.py b/example/api/src/queries/cities.py index 4234f8e2..7fd88273 100644 --- a/example/api/src/queries/cities.py +++ b/example/api/src/queries/cities.py @@ -1,5 +1,5 @@ from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.object_graph_type import ObjectGraphType +from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder @@ -25,15 +25,15 @@ class CitySort(Sort[City]): self.field("name", SortOrder) -class CityGraphType(ObjectGraphType): +class CityGraphType(GraphType[City]): def __init__(self): - ObjectGraphType.__init__(self) + GraphType.__init__(self) - self.string_field( + self.int_field( "id", - resolver=lambda user, *_: user.id, + resolver=lambda root: root.id, ) self.string_field( "name", - resolver=lambda user, *_: user.name, + resolver=lambda root: root.name, ) diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index 0f61c27c..2f2ba633 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -11,7 +11,7 @@ class HelloQuery(Query): Query.__init__(self) self.string_field( "message", - resolver=lambda *_, name: f"Hello {name} {get_request().state.request_id}", + resolver=lambda name: f"Hello {name} {get_request().state.request_id}", ).with_argument(str, "name", "Name to greet", "world") self.collection_field( @@ -19,12 +19,12 @@ class HelloQuery(Query): "users", UserFilter, UserSort, - resolver=lambda *_: users, + resolver=lambda: users, ) self.collection_field( CityGraphType, "cities", CityFilter, CitySort, - resolver=lambda *_: cities, + resolver=lambda: cities, ) diff --git a/example/api/src/queries/user.py b/example/api/src/queries/user.py index 3c4dd70c..a35a1780 100644 --- a/example/api/src/queries/user.py +++ b/example/api/src/queries/user.py @@ -1,6 +1,5 @@ from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.object_graph_type import ObjectGraphType - +from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder @@ -25,15 +24,16 @@ class UserSort(Sort[User]): self.field("name", SortOrder) -class UserGraphType(ObjectGraphType): - def __init__(self): - ObjectGraphType.__init__(self) +class UserGraphType(GraphType[User]): - self.string_field( + def __init__(self): + GraphType.__init__(self) + + self.int_field( "id", - resolver=lambda user, *_: user.id, + resolver=lambda root: root.id, ) self.string_field( "name", - resolver=lambda user, *_: user.name, + resolver=lambda root: root.name, ) diff --git a/src/cpl-graphql/cpl/graphql/abc/__init__.py b/src/cpl-graphql/cpl/graphql/abc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py new file mode 100644 index 00000000..1c0b6592 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py @@ -0,0 +1,9 @@ +from typing import Protocol, Type, runtime_checkable + +from cpl.graphql.schema.field import Field + + +@runtime_checkable +class StrawberryProtocol(Protocol): + def to_strawberry(self) -> Type: ... + def get_fields(self) -> dict[str, Field]: ... diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 29d9d79d..70efa400 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,17 +1,15 @@ from cpl.api.api_module import ApiModule from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider -from cpl.graphql.schema.collection import CollectionGraphType from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema from cpl.graphql.service.service import GraphQLService -from cpl.graphql.service.type_converter import TypeConverter class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [TypeConverter, Schema] - scoped = [GraphQLService, RootQuery, CollectionGraphType] + singleton = [Schema, RootQuery] + scoped = [GraphQLService] @staticmethod def configure(services: ServiceProvider) -> None: diff --git a/src/cpl-graphql/cpl/graphql/schema/argument.py b/src/cpl-graphql/cpl/graphql/schema/argument.py index 2f3b938c..cbf8b32f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/argument.py +++ b/src/cpl-graphql/cpl/graphql/schema/argument.py @@ -1,9 +1,21 @@ +from typing import Any + + class Argument: - def __init__(self, t: type, name: str, description: str = None, default_value=None): + + def __init__( + self, + t: type, + name: str, + description: str = None, + default_value: Any = None, + optional: bool = None, + ): self._type = t self._name = name self._description = description self._default_value = default_value + self._optional = optional @property def type(self) -> type: @@ -18,5 +30,9 @@ class Argument: return self._description @property - def default_value(self): + def default_value(self) -> Any | None: return self._default_value + + @property + def optional(self) -> bool | None: + return self._optional diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index f14269fc..68b8aa69 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -1,18 +1,53 @@ -from typing import Generic, Type +from typing import Type, Dict, List + +import strawberry from cpl.core.typing import T -from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol -class Collection(Generic[T]): +class CollectionGraphTypeFactory: + _cache: Dict[Type, Type] = {} + + @classmethod + def get(cls, node_type: Type[StrawberryProtocol]) -> Type: + if node_type in cls._cache: + return cls._cache[node_type] + + gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type + + gql_type = strawberry.type( + type( + f"{node_type.__name__}Collection", + (), + { + "__annotations__": { + "nodes": List[gql_node], + "total_count": int, + "count": int, + } + }, + ) + ) + + cls._cache[node_type] = gql_type + return gql_type + + +class Collection: def __init__(self, nodes: list[T], total_count: int, count: int): - self.nodes = nodes - self.totalCount = total_count - self.count = count + self._nodes = nodes + self._total_count = total_count + self._count = count -class CollectionGraphType(GraphType[T]): - def __init__(self, t: Type[GraphType[T]]): - GraphType.__init__(self) - self.string_field("totalCount", resolver=lambda obj, *_: obj.totalCount) - self.string_field("count", resolver=lambda obj, *_: obj.count) - self.list_field("nodes", t, resolver=lambda obj, *_: obj.nodes) + @property + def nodes(self) -> list[T]: + return self._nodes + + @property + def total_count(self) -> int: + return self._total_count + + @property + def count(self) -> int: + return self._count diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index e6358e83..2231e11c 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -6,11 +6,24 @@ from cpl.graphql.typing import TQuery, Resolver class Field: - def __init__(self, name: str, gql_type: type, resolver: Resolver = None, subquery: TQuery = None): + def __init__( + self, + name: str, + gql_type: type = None, + resolver: Resolver = None, + optional=None, + default=None, + subquery: TQuery = None, + parent_type=None, + ): self._name = name self._gql_type = gql_type self._resolver = resolver + self._optional = optional or True + self._default = default + self._subquery = subquery + self._parent_type = parent_type self._args: dict[str, Argument] = {} @@ -26,6 +39,14 @@ class Field: def resolver(self) -> callable: return self._resolver + @property + def optional(self) -> bool | None: + return self._optional + + @property + def default(self): + return self._default + @property def args(self) -> dict: return self._args @@ -34,10 +55,18 @@ class Field: def subquery(self) -> TQuery | None: return self._subquery - def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None) -> Self: + @property + def parent_type(self): + return self._parent_type + + @property + def arguments(self) -> dict[str, Argument]: + return self._args + + def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") - self._args[name] = Argument(arg_type, name, description, default_value) + self._args[name] = Argument(arg_type, name, description, default_value, optional) return self def with_arguments(self, args: list[Argument]) -> Self: @@ -45,5 +74,5 @@ class Field: if not isinstance(arg, Argument): raise ValueError(f"Expected Argument instance, got {type(arg)}") - self.with_argument(arg.type, arg.name, arg.description, arg.default_value) + self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) return self diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index 26339bbc..2f76c4b4 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -3,7 +3,5 @@ from cpl.graphql.schema.input import Input class Filter(Input[T]): - def __init__( - self, - ): + def __init__(self): Input.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/graph_type.py b/src/cpl-graphql/cpl/graphql/schema/graph_type.py index 8fff69cf..e829b82d 100644 --- a/src/cpl-graphql/cpl/graphql/schema/graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/graph_type.py @@ -4,7 +4,7 @@ from cpl.core.typing import T from cpl.graphql.schema.query import Query -class GraphType(Generic[T], Query): +class GraphType(Query, Generic[T]): def __init__(self): Query.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 8f66c69c..4c9afc86 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,26 +1,34 @@ -from datetime import datetime -from enum import Enum -from typing import Type, Generic +from typing import Generic, Dict, Type, Any, Optional -import graphene +import strawberry from cpl.core.typing import T +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field -class Input(Generic[T], graphene.InputObjectType): - def __init__( - self, - ): - graphene.InputObjectType.__init__(self) - self._fields: dict[str, Field] = {} +class Input(StrawberryProtocol, Generic[T]): + def __init__(self): + self._fields: Dict[str, Field] = {} def get_fields(self) -> dict[str, Field]: return self._fields - def field( - self, - field: str, - t: Type["Input"] | Type[int | str | bool | datetime | list | Enum], - ): - self._fields[field] = Field(field, t) + def field(self, name: str, typ: type, optional: bool = True): + self._fields[name] = Field(name, typ, optional=optional) + + def to_strawberry(self) -> Type: + annotations = {} + namespace = {} + + for name, f in self._fields.items(): + ann = f.type if not f.optional else Optional[f.type] + annotations[name] = ann + + if f.optional: + namespace[name] = None + elif f.default is not None: + namespace[name] = f.default + + namespace["__annotations__"] = annotations + return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace)) diff --git a/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py deleted file mode 100644 index 5cc46a0a..00000000 --- a/src/cpl-graphql/cpl/graphql/schema/object_graph_type.py +++ /dev/null @@ -1,9 +0,0 @@ -from cpl.core.typing import T -from cpl.graphql.schema.graph_type import GraphType -from cpl.graphql.schema.query import Query - - -class ObjectGraphType(GraphType[T], Query): - - def __init__(self): - Query.__init__(self) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 0e14d6b9..a453734a 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,21 +1,27 @@ -from typing import Callable, Type +import inspect +from typing import Callable, Type, Any, Optional -from graphene import ObjectType +import strawberry +from strawberry.exceptions import StrawberryException -from cpl.graphql.schema.argument import Argument +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field -from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.typing import Resolver -class Query(ObjectType): +class Query(StrawberryProtocol): - def __init__(self): - from cpl.graphql.schema.field import Field + @inject + def __init__(self, provider: ServiceProvider): + self._provider = provider - ObjectType.__init__(self) + from cpl.graphql.service.schema import Schema + + self._schema = provider.get_service(Schema) self._fields: dict[str, Field] = {} def get_fields(self) -> dict[str, Field]: @@ -25,69 +31,137 @@ class Query(ObjectType): self, name: str, t: type, - resolver: Callable | None = None, - ) -> "Field": + resolver: Resolver = None, + ) -> Field: from cpl.graphql.schema.field import Field self._fields[name] = Field(name, t, resolver) return self._fields[name] - def with_query(self, name: str, subquery: Type["Query"]): - from cpl.graphql.schema.field import Field - - f = Field(name=name, gql_type=subquery, resolver=lambda root, info, **kwargs: {}, subquery=subquery) - self._fields[name] = f - return self._fields[name] - - def string_field(self, name: str, resolver: Resolver = None) -> "Field": + def string_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, str, resolver) - def int_field(self, name: str, resolver: Resolver = None) -> "Field": + def int_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, int, resolver) - def float_field(self, name: str, resolver: Resolver = None) -> "Field": + def float_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, float, resolver) - def bool_field(self, name: str, resolver: Resolver = None) -> "Field": + def bool_field(self, name: str, resolver: Resolver = None) -> Field: return self.field(name, bool, resolver) - def list_field(self, name: str, t: type, resolver: Resolver = None) -> "Field": + def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: return self.field(name, list[t], resolver) + def with_query(self, name: str, subquery_cls: Type["Query"]): + sub = self._provider.get_service(subquery_cls) + if not sub: + raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider") + + self.field(name, sub.to_strawberry(), lambda: sub) + return self + def collection_field( - self, t: type, name: str, filter_type: type, sort_type: type, resolver: Resolver = None - ) -> "Field": - from cpl.graphql.schema.collection import Collection, CollectionGraphType + self, + t: type, + name: str, + filter_type: Type[StrawberryProtocol], + sort_type: Type[StrawberryProtocol], + resolver: Callable, + ) -> Field: + # self._schema.with_type(filter_type) + # self._schema.with_type(sort_type) - def _resolve_collection(*_, filter: Filter, sort: Sort, skip: int, take: int): + def _resolve_collection(filter=None, sort=None, skip=0, take=10): items = resolver() + if filter: + for field, value in filter.__dict__.items(): + if value is None: + continue + items = [i for i in items if getattr(i, field) == value] - for field in filter or []: - if filter[field] is None: - continue - - items = [item for item in items if getattr(item, field) == filter[field]] - - for field in sort or []: - if sort[field] is None: - continue - - reverse = sort[field] == SortOrder.DESC - items = sorted(items, key=lambda item: getattr(item, field), reverse=reverse) - + if sort: + for field, direction in sort.__dict__.items(): + reverse = direction == SortOrder.DESC + items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse) total_count = len(items) paged = items[skip : skip + take] return Collection(nodes=paged, total_count=total_count, count=len(paged)) - # base = getattr(t, "__gqlname__", t.__class__.__name__) - wrapper = CollectionGraphType(t) - # wrapper.set_graphql_name(f"{base}Collection") - f = self.field(name, wrapper, resolver=_resolve_collection) - return f.with_arguments( - [ - Argument(filter_type, "filter"), - Argument(sort_type, "sort"), - Argument(int, "skip", default_value=0), - Argument(int, "take", default_value=10), - ] - ) + filter = self._provider.get_service(filter_type) + if not filter: + raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider") + + sort = self._provider.get_service(sort_type) + if not sort: + raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + + f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection) + f.with_argument(filter.to_strawberry(), "filter") + f.with_argument(sort.to_strawberry(), "sort") + f.with_argument(int, "skip", default_value=0) + f.with_argument(int, "take", default_value=10) + return f + + @staticmethod + def _build_resolver(f: "Field"): + params: list[inspect.Parameter] = [] + for arg in f.arguments.values(): + ann = Optional[arg.type] if arg.optional else arg.type + + if arg.default_value is None: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + ) + else: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + default=arg.default_value, + ) + + params.append(param) + + sig = inspect.Signature(parameters=params, return_annotation=f.type) + + def _resolver(*args, **kwargs): + return f.resolver(*args, **kwargs) if f.resolver else None + + _resolver.__signature__ = sig + return _resolver + + def _field_to_strawberry(self, f: Field) -> Any: + try: + if f.resolver: + ann = getattr(f.resolver, "__annotations__", {}) + if "return" not in ann or ann["return"] is None: + ann = dict(ann) + ann["return"] = f.type + f.resolver.__annotations__ = ann + + if f.arguments: + resolver = self._build_resolver(f) + return strawberry.field(resolver=resolver) + + if not f.resolver: + return strawberry.field(resolver=lambda *_, **__: None) + + return strawberry.field(resolver=f.resolver) + except StrawberryException as e: + raise Exception( + f"Error converting field '{f.name}' to strawberry field: {e}" + ) from e + + def to_strawberry(self) -> Type: + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + annotations[name] = f.type + namespace[name] = self._field_to_strawberry(f) + + namespace["__annotations__"] = annotations + return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace)) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 9912c739..23627ee4 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,43 +1,54 @@ -import graphene +from typing import Type, Self + +import strawberry from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider -from cpl.graphql.schema.collection import CollectionGraphType -from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.root_query import RootQuery -from cpl.graphql.service.type_converter import TypeConverter class Schema: - def __init__(self, logger: APILogger, converter: TypeConverter, query: RootQuery, provider: ServiceProvider): + def __init__(self, logger: APILogger, provider: ServiceProvider): self._logger = logger self._provider = provider - self._converter = converter - self._types = set(GraphType.__subclasses__()) - self._types.remove(CollectionGraphType) + self._types: dict[str, Type[StrawberryProtocol]] = {} - self._query = query self._schema = None @property - def schema(self) -> graphene.Schema | None: + def schema(self) -> strawberry.Schema | None: return self._schema @property def query(self) -> RootQuery: - return self._query + return self._provider.get_service(RootQuery) - def with_type(self, t: type[GraphType]): - self._types.add(t) + def with_type(self, t: Type[StrawberryProtocol]) -> Self: + self._types[t.__name__] = t return self - def build(self) -> graphene.Schema: - self._schema = graphene.Schema( - query=self._converter.to_graphene(self._query), + def _get_types(self): + types: list[Type] = [] + for t in self._types.values(): + t_obj = self._provider.get_service(t) + if not t_obj: + raise ValueError(f"Type '{t.__name__}' not registered in service provider") + types.append(t_obj.to_strawberry()) + + return types + + def build(self) -> strawberry.Schema: + query = self._provider.get_service(RootQuery) + if not query: + raise ValueError("RootQuery not registered in service provider") + + self._schema = strawberry.Schema( + query=query.to_strawberry(), mutation=None, subscription=None, - # types=[self._converter.to_graphene(t) for t in self._types] if len(self._types) > 0 else None, + types=self._get_types(), ) return self._schema diff --git a/src/cpl-graphql/cpl/graphql/service/service.py b/src/cpl-graphql/cpl/graphql/service/service.py index 54c4f388..f039ccbd 100644 --- a/src/cpl-graphql/cpl/graphql/service/service.py +++ b/src/cpl-graphql/cpl/graphql/service/service.py @@ -16,7 +16,7 @@ class GraphQLService: variables: Optional[Dict[str, Any]], request: TRequest, ) -> Dict[str, Any]: - result = await self._schema.execute_async( + result = await self._schema.execute( query, variable_values=variables, context_value={"request": request}, diff --git a/src/cpl-graphql/cpl/graphql/service/type_converter.py b/src/cpl-graphql/cpl/graphql/service/type_converter.py deleted file mode 100644 index bf483b42..00000000 --- a/src/cpl-graphql/cpl/graphql/service/type_converter.py +++ /dev/null @@ -1,89 +0,0 @@ -import typing -from enum import Enum -from inspect import isclass - -import graphene -from typing import Any, get_origin, get_args - -from cpl.dependency import ServiceProvider -from cpl.graphql.schema.argument import Argument -from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.graph_type import GraphType -from cpl.graphql.schema.object_graph_type import ObjectGraphType -from cpl.graphql.schema.sort.sort import Sort -from cpl.graphql.typing import Resolver -from cpl.graphql.utils.name_pipe import NamePipe - - -class TypeConverter: - __scalar_map: dict[Any, type[graphene.Scalar]] = { - str: graphene.String, - int: graphene.Int, - float: graphene.Float, - bool: graphene.Boolean, - } - - def __init__(self, provider: ServiceProvider): - self._provider = provider - - def _field_to_graphene(self, t: typing.Type[graphene.Scalar] | type, args: dict[str, Argument] = None, resolver: Resolver = None) -> graphene.Field: - arguments = {} - if args is not None: - arguments = { - arg.name: graphene.Argument(self.to_graphene(arg.type), name=arg.name, description=arg.description, default_value=arg.default_value) - for arg in args.values() - } - - return graphene.Field(t, args=arguments, resolver=resolver) - - def to_graphene(self, t: Any, name: str | None = None) -> Any: - try: - origin = get_origin(t) - args = get_args(t) - - if t in self.__scalar_map: - return self.__scalar_map[t] - - if origin in (list, typing.List): - if not args: - raise ValueError("List must specify element type, e.g. list[str]") - inner = self.to_graphene(args[0]) - return graphene.List(inner) - - if t is list or t is typing.List: - raise ValueError("List must be parametrized: list[str], list[int], list[UserQuery]") - - if isclass(t) and issubclass(t, Enum): - return graphene.Enum.from_enum(t) - - from cpl.graphql.schema.query import Query - if isinstance(t, type) and issubclass(t, (Query)): - query = self._provider.get_service(t) - if query is None: - raise ValueError(f"Could not resolve query of type {t}") - - t = query - - if isinstance(t, type) and issubclass(t, (ObjectGraphType, GraphType, Filter, Sort)): - t = t() - - if isinstance(t, (Query, Filter, Sort)): - attrs = {} - for field in t.get_fields().values(): - if isclass(field.type) and issubclass(field.type, Query) and field.subquery is not None: - subquery = self._provider.get_service(field.subquery) - sub = self.to_graphene(subquery, name=field.name.capitalize()) - attrs[field.name] = self._field_to_graphene(sub, field.args, field.resolver) - continue - - attrs[field.name] = self._field_to_graphene(self.to_graphene(field.type), field.args, field.resolver) - - class_name = NamePipe.to_str(name or t.__class__) - if isinstance(t, (Filter, Sort)): - return type(class_name, (graphene.InputObjectType,), attrs) - - return type(class_name, (graphene.ObjectType,), attrs) - - raise ValueError(f"Unsupported field type: {t}") - except Exception as e: - raise ValueError(f"Failed to convert type {t} to graphene type: {e}") from e \ No newline at end of file diff --git a/src/cpl-graphql/requirements.txt b/src/cpl-graphql/requirements.txt index abe92c36..d74de843 100644 --- a/src/cpl-graphql/requirements.txt +++ b/src/cpl-graphql/requirements.txt @@ -1,2 +1,2 @@ cpl-api -graphene==3.4.3 \ No newline at end of file +strawberry-graphql==0.282.0 \ No newline at end of file From d8c60defba9ae1d550405d7acf9b0aa494a3e0d1 Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 21:57:33 +0200 Subject: [PATCH 06/20] Further gql improvements & added test data #181 --- example/api/src/main.py | 53 ++++++++++++------- example/api/src/model/__init__.py | 0 example/api/src/model/post.py | 30 +++++++++++ example/api/src/model/post_dao.py | 11 ++++ example/api/src/model/post_query.py | 38 +++++++++++++ example/api/src/queries/hello.py | 39 ++++++++++++++ example/api/src/scripts/0-posts.sql | 10 ++++ example/api/src/test_data_seeder.py | 31 +++++++++++ example/database/src/model/city.py | 8 +-- example/database/src/model/user.py | 8 +-- .../auth/schema/_administration/api_key.py | 8 +-- .../auth/schema/_administration/auth_user.py | 11 ++-- .../schema/_administration/auth_user_dao.py | 2 +- .../schema/_permission/api_key_permission.py | 6 +-- .../cpl/auth/schema/_permission/permission.py | 10 ++-- .../cpl/auth/schema/_permission/role.py | 10 ++-- .../schema/_permission/role_permission.py | 10 ++-- .../cpl/auth/schema/_permission/role_user.py | 6 +-- .../cpl/core/utils/credential_manager.py | 9 ++-- .../database/abc/data_access_object_abc.py | 12 +++-- .../cpl/database/abc/db_join_model_abc.py | 6 +-- .../cpl/database/abc/db_model_abc.py | 6 +-- .../cpl/database/model/database_settings.py | 2 +- .../cpl/database/schema/executed_migration.py | 8 +-- .../cpl/graphql/schema/collection.py | 2 +- src/cpl-graphql/cpl/graphql/schema/query.py | 53 +++++++++++++++++-- .../cpl/graphql/schema/sort/sort_order.py | 4 +- 27 files changed, 305 insertions(+), 88 deletions(-) create mode 100644 example/api/src/model/__init__.py create mode 100644 example/api/src/model/post.py create mode 100644 example/api/src/model/post_dao.py create mode 100644 example/api/src/model/post_query.py create mode 100644 example/api/src/scripts/0-posts.sql create mode 100644 example/api/src/test_data_seeder.py diff --git a/example/api/src/main.py b/example/api/src/main.py index bfb953fb..777f62fe 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,7 +1,7 @@ from starlette.responses import JSONResponse from api.src.queries.cities import CityGraphType, CityFilter, CitySort -from api.src.queries.hello import UserGraphType +from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder @@ -14,9 +14,12 @@ from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule from cpl.graphql.application.graphql_app import GraphQLApp from cpl.graphql.graphql_module import GraphQLModule +from model.post_dao import PostDao +from model.post_query import PostFilter, PostSort, PostGraphType from queries.hello import HelloQuery from scoped_service import ScopedService from service import PingService +from test_data_seeder import TestDataSeeder def main(): @@ -27,29 +30,38 @@ def main(): Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True) # builder.services.add_logging() - builder.services.add_structured_logging() - builder.services.add_transient(PingService) - builder.services.add_module(MySQLModule) - builder.services.add_module(ApiModule) - builder.services.add_module(GraphQLModule) + ( + builder.services.add_structured_logging() + .add_transient(PingService) + .add_module(MySQLModule) + .add_module(ApiModule) + .add_module(GraphQLModule) + .add_scoped(ScopedService) + .add_cache(AuthUser) + .add_cache(Role) + .add_transient(CityGraphType) + .add_transient(CityFilter) + .add_transient(CitySort) + .add_transient(UserGraphType) + .add_transient(UserFilter) + .add_transient(UserSort) + .add_transient(AuthUserGraphType) + .add_transient(AuthUserFilter) + .add_transient(AuthUserSort) + .add_transient(HelloQuery) + # posts + .add_transient(PostDao) + .add_transient(PostGraphType) + .add_transient(PostFilter) + .add_transient(PostSort) - builder.services.add_scoped(ScopedService) - - builder.services.add_cache(AuthUser) - builder.services.add_cache(Role) - - builder.services.add_transient(CityGraphType) - builder.services.add_transient(CityFilter) - builder.services.add_transient(CitySort) - - builder.services.add_transient(UserGraphType) - builder.services.add_transient(UserFilter) - builder.services.add_transient(UserSort) - - builder.services.add_transient(HelloQuery) + # test data + .add_singleton(TestDataSeeder) + ) app = builder.build() app.with_logging() + app.with_migrations("./scripts") app.with_authentication() app.with_authorization() @@ -66,6 +78,7 @@ def main(): schema = app.with_graphql() schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) + schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) app.with_playground() app.with_graphiql() diff --git a/example/api/src/model/__init__.py b/example/api/src/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/example/api/src/model/post.py b/example/api/src/model/post.py new file mode 100644 index 00000000..a2d22d60 --- /dev/null +++ b/example/api/src/model/post.py @@ -0,0 +1,30 @@ +from datetime import datetime +from typing import Self + +from cpl.core.typing import SerialId +from cpl.database.abc import DbModelABC + + +class Post(DbModelABC[Self]): + + def __init__( + self, + id: int, + title: str, + content: str, + deleted: bool = False, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, + ): + DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._title = title + self._content = content + + @property + def title(self) -> str: + return self._title + + @property + def content(self) -> str: + return self._content diff --git a/example/api/src/model/post_dao.py b/example/api/src/model/post_dao.py new file mode 100644 index 00000000..da283fef --- /dev/null +++ b/example/api/src/model/post_dao.py @@ -0,0 +1,11 @@ +from cpl.database.abc import DbModelDaoABC +from model.post import Post + + +class PostDao(DbModelDaoABC): + + def __init__(self): + DbModelDaoABC.__init__(self, Post, "posts") + + self.attribute(Post.title, str) + self.attribute(Post.content, str) \ No newline at end of file diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py new file mode 100644 index 00000000..6e25dddc --- /dev/null +++ b/example/api/src/model/post_query.py @@ -0,0 +1,38 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder +from model.post import Post + +class PostFilter(Filter[Post]): + def __init__(self): + Filter.__init__(self) + self.field("id", int) + self.field("title", str) + self.field("content", str) + +class PostSort(Sort[Post]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("title", SortOrder) + self.field("content", SortOrder) + + +class PostGraphType(GraphType[Post]): + + def __init__(self): + GraphType.__init__(self) + + self.int_field( + "id", + resolver=lambda root: root.id, + ) + self.string_field( + "title", + resolver=lambda root: root.title, + ) + self.string_field( + "content", + resolver=lambda root: root.content, + ) \ No newline at end of file diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index 2f2ba633..addd9173 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -1,11 +1,43 @@ from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City from api.src.queries.user import User, UserFilter, UserSort, UserGraphType from cpl.api.middleware.request import get_request +from cpl.auth.schema import AuthUserDao, AuthUser +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.query import Query +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder users = [User(i, f"User {i}") for i in range(1, 101)] cities = [City(i, f"City {i}") for i in range(1, 101)] +class AuthUserFilter(Filter[AuthUser]): + def __init__(self): + Filter.__init__(self) + self.field("id", int) + self.field("username", str) + + +class AuthUserSort(Sort[AuthUser]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("username", SortOrder) + +class AuthUserGraphType(GraphType[AuthUser]): + + def __init__(self): + GraphType.__init__(self) + + self.int_field( + "id", + resolver=lambda root: root.id, + ) + self.string_field( + "username", + resolver=lambda root: root.username, + ) + class HelloQuery(Query): def __init__(self): Query.__init__(self) @@ -28,3 +60,10 @@ class HelloQuery(Query): CitySort, resolver=lambda: cities, ) + self.dao_collection_field( + AuthUserGraphType, + AuthUserDao, + "authUsers", + AuthUserFilter, + AuthUserSort, + ) diff --git a/example/api/src/scripts/0-posts.sql b/example/api/src/scripts/0-posts.sql new file mode 100644 index 00000000..bf2ecc62 --- /dev/null +++ b/example/api/src/scripts/0-posts.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS `posts` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `title` VARCHAR(64) NOT NULL, + `content` VARCHAR(512) NOT NULL, + deleted BOOLEAN NOT NULL DEFAULT FALSE, + editorId INT NULL, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(`id`) +); \ No newline at end of file diff --git a/example/api/src/test_data_seeder.py b/example/api/src/test_data_seeder.py new file mode 100644 index 00000000..f50eea6f --- /dev/null +++ b/example/api/src/test_data_seeder.py @@ -0,0 +1,31 @@ +from faker import Faker + +from cpl.database.abc import DataSeederABC +from cpl.query import Enumerable +from model.post import Post +from model.post_dao import PostDao + + +fake = Faker() + + +class TestDataSeeder(DataSeederABC): + + def __init__(self, posts: PostDao): + DataSeederABC.__init__(self) + + self._posts = posts + + async def seed(self): + if await self._posts.count() == 0: + await self._seed_posts() + + async def _seed_posts(self): + posts = Enumerable.range(0, 100).select( + lambda x: Post( + id=0, + title=fake.sentence(nb_words=6), + content=fake.paragraph(nb_sentences=6), + ) + ).to_list() + await self._posts.create_many(posts, skip_editor=True) diff --git a/example/database/src/model/city.py b/example/database/src/model/city.py index c98bef85..2d61f92f 100644 --- a/example/database/src/model/city.py +++ b/example/database/src/model/city.py @@ -5,16 +5,16 @@ from cpl.core.typing import SerialId from cpl.database.abc.db_model_abc import DbModelABC -class City(DbModelABC): +class City(DbModelABC[Self]): def __init__( self, id: int, name: str, zip: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/example/database/src/model/user.py b/example/database/src/model/user.py index 445c56b7..e0116423 100644 --- a/example/database/src/model/user.py +++ b/example/database/src/model/user.py @@ -5,7 +5,7 @@ from cpl.core.typing import SerialId from cpl.database.abc.db_model_abc import DbModelABC -class User(DbModelABC): +class User(DbModelABC[Self]): def __init__( self, @@ -13,9 +13,9 @@ class User(DbModelABC): name: str, city_id: int = 0, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_administration/api_key.py b/src/cpl-auth/cpl/auth/schema/_administration/api_key.py index 16f57a7d..995628e2 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/api_key.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/api_key.py @@ -1,6 +1,6 @@ import secrets from datetime import datetime -from typing import Optional, Union +from typing import Optional, Union, Self from async_property import async_property @@ -16,7 +16,7 @@ from cpl.dependency.service_provider import ServiceProvider _logger = Logger(__name__) -class ApiKey(DbModelABC): +class ApiKey(DbModelABC[Self]): def __init__( self, @@ -25,8 +25,8 @@ class ApiKey(DbModelABC): key: Union[str, bytes], deleted: bool = False, editor_id: Optional[Id] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._identifier = identifier diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py index 5409e468..e9eff14d 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from typing import Optional +from typing import Optional, Self from async_property import async_property from keycloak import KeycloakGetError @@ -13,15 +13,15 @@ from cpl.database.logger import DBLogger from cpl.dependency import get_provider -class AuthUser(DbModelABC): +class AuthUser(DbModelABC[Self]): def __init__( self, id: SerialId, keycloak_id: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._keycloak_id = keycloak_id @@ -87,4 +87,3 @@ class AuthUser(DbModelABC): self._keycloak_id = str(uuid.UUID(int=0)) await auth_user_dao.update(self) - diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py index 8963259f..4b27549a 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py @@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser from cpl.database import TableManager from cpl.database.abc import DbModelDaoABC from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder -from cpl.dependency import ServiceProvider +from cpl.dependency.context import get_provider class AuthUserDao(DbModelDaoABC[AuthUser]): diff --git a/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py index 8a7f8e4b..59132955 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py @@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC): api_key_id: SerialId, permission_id: SerialId, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated) self._api_key_id = api_key_id diff --git a/src/cpl-auth/cpl/auth/schema/_permission/permission.py b/src/cpl-auth/cpl/auth/schema/_permission/permission.py index e5bb046d..8db9c477 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/permission.py @@ -1,20 +1,20 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -class Permission(DbModelABC): +class Permission(DbModelABC[Self]): def __init__( self, id: SerialId, name: str, description: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role.py b/src/cpl-auth/cpl/auth/schema/_permission/role.py index 325fec91..24a5d82d 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from async_property import async_property @@ -9,16 +9,16 @@ from cpl.database.abc import DbModelABC from cpl.dependency import ServiceProvider -class Role(DbModelABC): +class Role(DbModelABC[Self]): def __init__( self, id: SerialId, name: str, description: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py index 33b60f04..82bacb4a 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from async_property import async_property @@ -8,16 +8,16 @@ from cpl.database.abc import DbModelABC from cpl.dependency import ServiceProvider -class RolePermission(DbModelABC): +class RolePermission(DbModelABC[Self]): def __init__( self, id: SerialId, role_id: SerialId, permission_id: SerialId, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._role_id = role_id diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py index 6f1f659e..5db0f892 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py @@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC): user_id: SerialId, role_id: SerialId, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated) self._user_id = user_id diff --git a/src/cpl-core/cpl/core/utils/credential_manager.py b/src/cpl-core/cpl/core/utils/credential_manager.py index d030dc94..126afd6a 100644 --- a/src/cpl-core/cpl/core/utils/credential_manager.py +++ b/src/cpl-core/cpl/core/utils/credential_manager.py @@ -2,10 +2,6 @@ import os from cryptography.fernet import Fernet -from cpl.core.log.logger import Logger - -_logger = Logger(__name__) - class CredentialManager: r"""Handles credential encryption and decryption""" @@ -14,6 +10,7 @@ class CredentialManager: @classmethod def with_secret(cls, file: str = None): + from cpl.core.log import Logger if file is None: file = ".secret" @@ -25,12 +22,12 @@ class CredentialManager: with open(file, "w") as secret_file: secret_file.write(Fernet.generate_key().decode()) secret_file.close() - _logger.warning("Secret file not found, regenerating") + Logger(__name__).warning("Secret file not found, regenerating") with open(file, "r") as secret_file: secret = secret_file.read().strip() if secret == "" or secret is None: - _logger.fatal("No secret found in .secret file.") + Logger(__name__).fatal("No secret found in .secret file.") cls._secret = str(secret) diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py index 95a12e05..44f2a0bf 100644 --- a/src/cpl-database/cpl/database/abc/data_access_object_abc.py +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -46,6 +46,10 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): def table_name(self) -> str: return self._table_name + @property + def type(self) -> Type[T_DBM]: + return self._model_type + def has_attribute(self, attr_name: Attribute) -> bool: """ Check if the attribute exists in the DAO @@ -490,16 +494,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): table, join_condition = self.__foreign_tables[attr] builder.with_left_join(table, join_condition) - if filters: + if filters is not None: await self._build_conditions(builder, filters, external_table_deps) - if sorts: + if sorts is not None: self._build_sorts(builder, sorts, external_table_deps) - if take: + if take is not None: builder.with_limit(take) - if skip: + if skip is not None: builder.with_offset(skip) for external_table in external_table_deps: diff --git a/src/cpl-database/cpl/database/abc/db_join_model_abc.py b/src/cpl-database/cpl/database/abc/db_join_model_abc.py index c81bd50d..55327419 100644 --- a/src/cpl-database/cpl/database/abc/db_join_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_join_model_abc.py @@ -12,9 +12,9 @@ class DbJoinModelABC[T](DbModelABC[T]): source_id: Id, foreign_id: Id, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) diff --git a/src/cpl-database/cpl/database/abc/db_model_abc.py b/src/cpl-database/cpl/database/abc/db_model_abc.py index edbd1f3b..5791afe3 100644 --- a/src/cpl-database/cpl/database/abc/db_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_model_abc.py @@ -10,9 +10,9 @@ class DbModelABC(ABC, Generic[T]): self, id: Id, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): self._id = id self._deleted = deleted diff --git a/src/cpl-database/cpl/database/model/database_settings.py b/src/cpl-database/cpl/database/model/database_settings.py index ccf1ad44..fa6154af 100644 --- a/src/cpl-database/cpl/database/model/database_settings.py +++ b/src/cpl-database/cpl/database/model/database_settings.py @@ -1,6 +1,6 @@ from typing import Optional -from cpl.core.configuration import Configuration +from cpl.core.configuration.configuration import Configuration from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC diff --git a/src/cpl-database/cpl/database/schema/executed_migration.py b/src/cpl-database/cpl/database/schema/executed_migration.py index 3b9ed1c5..02b99dc3 100644 --- a/src/cpl-database/cpl/database/schema/executed_migration.py +++ b/src/cpl-database/cpl/database/schema/executed_migration.py @@ -1,15 +1,15 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from cpl.database.abc import DbModelABC -class ExecutedMigration(DbModelABC): +class ExecutedMigration(DbModelABC[Self]): def __init__( self, migration_id: str, - created: Optional[datetime] = None, - modified: Optional[datetime] = None, + created: datetime | None= None, + modified: datetime | None= None, ): DbModelABC.__init__(self, migration_id, False, created, modified) diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 68b8aa69..6cac07a8 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -18,7 +18,7 @@ class CollectionGraphTypeFactory: gql_type = strawberry.type( type( - f"{node_type.__name__}Collection", + f"{node_type.__name__.replace("GraphType", "")}Collection", (), { "__annotations__": { diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index a453734a..5539c4c6 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -4,6 +4,7 @@ from typing import Callable, Type, Any, Optional import strawberry from strawberry.exceptions import StrawberryException +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -69,9 +70,6 @@ class Query(StrawberryProtocol): sort_type: Type[StrawberryProtocol], resolver: Callable, ) -> Field: - # self._schema.with_type(filter_type) - # self._schema.with_type(sort_type) - def _resolve_collection(filter=None, sort=None, skip=0, take=10): items = resolver() if filter: @@ -103,6 +101,53 @@ class Query(StrawberryProtocol): f.with_argument(int, "take", default_value=10) return f + def dao_collection_field( + self, + t: Type[StrawberryProtocol], + dao_type: Type[DataAccessObjectABC], + name: str, + filter_type: Type[StrawberryProtocol], + sort_type: Type[StrawberryProtocol], + ) -> Field: + assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC" + dao = self._provider.get_service(dao_type) + if not dao: + raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider") + + filter = self._provider.get_service(filter_type) + if not filter: + raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider") + + sort = self._provider.get_service(sort_type) + if not sort: + raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + + async def _resolver(filter=None, sort=None, take=10, skip=0): + sort_dict = None + + if sort is not None: + sort_dict = {} + for k, v in sort.__dict__.items(): + if v is None: + continue + + if isinstance(v, SortOrder): + sort_dict[k] = str(v.value).lower() + continue + + sort_dict[k] = str(v).lower() + + total_count = await dao.count(filter) + data = await dao.find_by(filter, sort_dict, take, skip) + return Collection(nodes=data, total_count=total_count, count=len(data)) + + f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver) + f.with_argument(filter.to_strawberry(), "filter") + f.with_argument(sort.to_strawberry(), "sort") + f.with_argument(int, "skip", default_value=0) + f.with_argument(int, "take", default_value=10) + return f + @staticmethod def _build_resolver(f: "Field"): params: list[inspect.Parameter] = [] @@ -164,4 +209,4 @@ class Query(StrawberryProtocol): namespace[name] = self._field_to_strawberry(f) namespace["__annotations__"] = annotations - return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace)) + return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py index cc3122a4..cb8e8177 100644 --- a/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py +++ b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py @@ -2,5 +2,5 @@ from enum import Enum, auto class SortOrder(Enum): - ASC = auto() - DESC = auto() \ No newline at end of file + ASC = "ASC" + DESC = "DESC" \ No newline at end of file From a12a4082dbdb6fd75404515bcd755be8520b906d Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 27 Sep 2025 22:35:48 +0200 Subject: [PATCH 07/20] Dao complex filtering #181 --- example/api/src/model/post_query.py | 9 ++--- src/cpl-graphql/cpl/graphql/graphql_module.py | 6 +++ .../cpl/graphql/schema/filter/bool_filter.py | 10 +++++ .../cpl/graphql/schema/filter/date_filter.py | 18 +++++++++ .../cpl/graphql/schema/filter/filter.py | 16 ++++++++ .../cpl/graphql/schema/filter/int_filter.py | 16 ++++++++ .../graphql/schema/filter/string_filter.py | 16 ++++++++ src/cpl-graphql/cpl/graphql/schema/input.py | 38 ++++++++++++++----- src/cpl-graphql/cpl/graphql/schema/query.py | 24 +++++++++++- 9 files changed, 137 insertions(+), 16 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 6e25dddc..2e5b2998 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -7,9 +7,9 @@ from model.post import Post class PostFilter(Filter[Post]): def __init__(self): Filter.__init__(self) - self.field("id", int) - self.field("title", str) - self.field("content", str) + self.int_field("id") + self.string_field("title") + self.string_field("content") class PostSort(Sort[Post]): def __init__(self): @@ -18,7 +18,6 @@ class PostSort(Sort[Post]): self.field("title", SortOrder) self.field("content", SortOrder) - class PostGraphType(GraphType[Post]): def __init__(self): @@ -35,4 +34,4 @@ class PostGraphType(GraphType[Post]): self.string_field( "content", resolver=lambda root: root.content, - ) \ No newline at end of file + ) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 70efa400..2d5d6b93 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,6 +1,11 @@ from cpl.api.api_module import ApiModule from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema from cpl.graphql.service.service import GraphQLService @@ -10,6 +15,7 @@ class GraphQLModule(Module): dependencies = [ApiModule] singleton = [Schema, RootQuery] scoped = [GraphQLService] + transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] @staticmethod def configure(services: ServiceProvider) -> None: diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py new file mode 100644 index 00000000..4be0db85 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py @@ -0,0 +1,10 @@ +from cpl.graphql.schema.input import Input + + +class BoolFilter(Input[bool]): + def __init__(self): + super().__init__() + self.field("equal", bool, optional=True) + self.field("notEqual", bool, optional=True) + self.field("isNull", bool, optional=True) + self.field("isNotNull", bool, optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py new file mode 100644 index 00000000..2dd1bcf8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from cpl.graphql.schema.input import Input + + +class DateFilter(Input[datetime]): + def __init__(self): + super().__init__() + self.field("equal", datetime, optional=True) + self.field("notEqual", datetime, optional=True) + self.field("greater", datetime, optional=True) + self.field("greaterOrEqual", datetime, optional=True) + self.field("less", datetime, optional=True) + self.field("lessOrEqual", datetime, optional=True) + self.field("isNull", datetime, optional=True) + self.field("isNotNull", datetime, optional=True) + self.field("in", list[datetime], optional=True) + self.field("notIn", list[datetime], optional=True) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index 2f76c4b4..d1d502e2 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -1,7 +1,23 @@ from cpl.core.typing import T +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.input import Input class Filter(Input[T]): def __init__(self): Input.__init__(self) + + def string_field(self, name: str): + self.field(name, StringFilter()) + + def int_field(self, name: str): + self.field(name, IntFilter()) + + def bool_field(self, name: str): + self.field(name, BoolFilter()) + + def date_field(self, name: str): + self.field(name, DateFilter()) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py new file mode 100644 index 00000000..be9eba74 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py @@ -0,0 +1,16 @@ +from cpl.graphql.schema.input import Input + + +class IntFilter(Input[int]): + def __init__(self): + super().__init__() + self.field("equal", int, optional=True) + self.field("notEqual", int, optional=True) + self.field("greater", int, optional=True) + self.field("greaterOrEqual", int, optional=True) + self.field("less", int, optional=True) + self.field("lessOrEqual", int, optional=True) + self.field("isNull", int, optional=True) + self.field("isNotNull", int, optional=True) + self.field("in", list[int], optional=True) + self.field("notIn", list[int], optional=True) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py new file mode 100644 index 00000000..7c060abc --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py @@ -0,0 +1,16 @@ +from cpl.graphql.schema.input import Input + + +class StringFilter(Input[str]): + def __init__(self): + super().__init__() + self.field("equal", str, optional=True) + self.field("notEqual", str, optional=True) + self.field("contains", str, optional=True) + self.field("notContains", str, optional=True) + self.field("startsWith", str, optional=True) + self.field("endsWith", str, optional=True) + self.field("isNull", str, optional=True) + self.field("isNotNull", str, optional=True) + self.field("in", list[str], optional=True) + self.field("notIn", list[str], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 4c9afc86..82ff31de 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,4 +1,4 @@ -from typing import Generic, Dict, Type, Any, Optional +from typing import Generic, Dict, Type, Optional, Self, Union import strawberry @@ -6,6 +6,7 @@ from cpl.core.typing import T from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field +_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} class Input(StrawberryProtocol, Generic[T]): def __init__(self): @@ -14,21 +15,40 @@ class Input(StrawberryProtocol, Generic[T]): def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: str, typ: type, optional: bool = True): + def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): self._fields[name] = Field(name, typ, optional=optional) + _registry: dict[type, Type] = {} + def to_strawberry(self) -> Type: + cls = self.__class__ + if cls in self._registry: + return self._registry[cls] + annotations = {} namespace = {} for name, f in self._fields.items(): - ann = f.type if not f.optional else Optional[f.type] - annotations[name] = ann + typ = f.type + if isinstance(typ, type) and issubclass(typ, Input): + typ = typ().to_strawberry() + elif isinstance(typ, Input): + typ = typ.to_strawberry() - if f.optional: - namespace[name] = None - elif f.default is not None: - namespace[name] = f.default + ann = typ if not f.optional else Optional[typ] + + py_name = name + "_" if name in _PYTHON_KEYWORDS else name + annotations[py_name] = ann + + field_args = {} + if py_name != name: + field_args["name"] = name + + default = None if f.optional else f.default + namespace[py_name] = strawberry.field(default=default, **field_args) namespace["__annotations__"] = annotations - return strawberry.input(type(f"{self.__class__.__name__}Input", (), namespace)) + + gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) + Input._registry[cls] = gql_type + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 5539c4c6..736de81f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -122,9 +122,29 @@ class Query(StrawberryProtocol): if not sort: raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + def input_to_dict(obj) -> dict | None: + if obj is None: + return None + + result = {} + for k, v in obj.__dict__.items(): + if v is None: + continue + + # verschachtelte Inputs rekursiv + if hasattr(v, "__dict__"): + result[k] = input_to_dict(v) + else: + result[k] = v + return result + async def _resolver(filter=None, sort=None, take=10, skip=0): + filter_dict = input_to_dict(filter) if filter is not None else None sort_dict = None + if filter is not None: + pass + if sort is not None: sort_dict = {} for k, v in sort.__dict__.items(): @@ -137,8 +157,8 @@ class Query(StrawberryProtocol): sort_dict[k] = str(v).lower() - total_count = await dao.count(filter) - data = await dao.find_by(filter, sort_dict, take, skip) + total_count = await dao.count(filter_dict) + data = await dao.find_by(filter_dict, sort_dict, take, skip) return Collection(nodes=data, total_count=total_count, count=len(data)) f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver) From 20e5da57702d969970664dfd524d524effb4aa20 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 01:09:46 +0200 Subject: [PATCH 08/20] Recursive complex filtering #181 --- example/api/src/main.py | 13 ++- example/api/src/model/author.py | 30 ++++++ example/api/src/model/author_dao.py | 11 +++ example/api/src/model/author_query.py | 37 ++++++++ example/api/src/model/post.py | 6 ++ example/api/src/model/post_dao.py | 6 +- example/api/src/model/post_query.py | 14 ++- example/api/src/scripts/0-posts.sql | 16 +++- example/api/src/test_data_seeder.py | 19 +++- .../cpl/database/mysql/mysql_pool.py | 92 ++++++++++++------- .../cpl/database/postgres/postgres_pool.py | 2 +- .../cpl/graphql/schema/collection.py | 8 +- .../cpl/graphql/schema/filter/filter.py | 5 + src/cpl-graphql/cpl/graphql/schema/input.py | 9 +- src/cpl-graphql/cpl/graphql/schema/query.py | 12 ++- .../cpl/graphql/utils/type_collector.py | 17 ++++ 16 files changed, 249 insertions(+), 48 deletions(-) create mode 100644 example/api/src/model/author.py create mode 100644 example/api/src/model/author_dao.py create mode 100644 example/api/src/model/author_query.py create mode 100644 src/cpl-graphql/cpl/graphql/utils/type_collector.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 777f62fe..7273381a 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -14,6 +14,8 @@ from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule from cpl.graphql.application.graphql_app import GraphQLApp from cpl.graphql.graphql_module import GraphQLModule +from model.author_dao import AuthorDao +from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao from model.post_query import PostFilter, PostSort, PostGraphType from queries.hello import HelloQuery @@ -49,14 +51,18 @@ def main(): .add_transient(AuthUserFilter) .add_transient(AuthUserSort) .add_transient(HelloQuery) + # test data + .add_singleton(TestDataSeeder) + # authors + .add_transient(AuthorDao) + .add_transient(AuthorGraphType) + .add_transient(AuthorFilter) + .add_transient(AuthorSort) # posts .add_transient(PostDao) .add_transient(PostGraphType) .add_transient(PostFilter) .add_transient(PostSort) - - # test data - .add_singleton(TestDataSeeder) ) app = builder.build() @@ -78,6 +84,7 @@ def main(): schema = app.with_graphql() schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) + schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) app.with_playground() diff --git a/example/api/src/model/author.py b/example/api/src/model/author.py new file mode 100644 index 00000000..05e2d3e3 --- /dev/null +++ b/example/api/src/model/author.py @@ -0,0 +1,30 @@ +from datetime import datetime +from typing import Self + +from cpl.core.typing import SerialId +from cpl.database.abc import DbModelABC + + +class Author(DbModelABC[Self]): + + def __init__( + self, + id: int, + first_name: str, + last_name: str, + deleted: bool = False, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, + ): + DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._first_name = first_name + self._last_name = last_name + + @property + def first_name(self) -> str: + return self._first_name + + @property + def last_name(self) -> str: + return self._last_name diff --git a/example/api/src/model/author_dao.py b/example/api/src/model/author_dao.py new file mode 100644 index 00000000..98b997a6 --- /dev/null +++ b/example/api/src/model/author_dao.py @@ -0,0 +1,11 @@ +from cpl.database.abc import DbModelDaoABC +from model.author import Author + + +class AuthorDao(DbModelDaoABC): + + def __init__(self): + DbModelDaoABC.__init__(self, Author, "authors") + + self.attribute(Author.first_name, str) + self.attribute(Author.last_name, str) \ No newline at end of file diff --git a/example/api/src/model/author_query.py b/example/api/src/model/author_query.py new file mode 100644 index 00000000..f7f1d1df --- /dev/null +++ b/example/api/src/model/author_query.py @@ -0,0 +1,37 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder +from model.author import Author + +class AuthorFilter(Filter[Author]): + def __init__(self): + Filter.__init__(self) + self.int_field("id") + self.string_field("firstName") + self.string_field("lastName") + +class AuthorSort(Sort[Author]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("firstName", SortOrder) + self.field("lastName", SortOrder) + +class AuthorGraphType(GraphType[Author]): + + def __init__(self): + GraphType.__init__(self) + + self.int_field( + "id", + resolver=lambda root: root.id, + ) + self.string_field( + "firstName", + resolver=lambda root: root.first_name, + ) + self.string_field( + "lastName", + resolver=lambda root: root.last_name, + ) diff --git a/example/api/src/model/post.py b/example/api/src/model/post.py index a2d22d60..d5801cd0 100644 --- a/example/api/src/model/post.py +++ b/example/api/src/model/post.py @@ -10,6 +10,7 @@ class Post(DbModelABC[Self]): def __init__( self, id: int, + author_id: SerialId, title: str, content: str, deleted: bool = False, @@ -18,9 +19,14 @@ class Post(DbModelABC[Self]): updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._author_id = author_id self._title = title self._content = content + @property + def author_id(self) -> SerialId: + return self._author_id + @property def title(self) -> str: return self._title diff --git a/example/api/src/model/post_dao.py b/example/api/src/model/post_dao.py index da283fef..be8e5668 100644 --- a/example/api/src/model/post_dao.py +++ b/example/api/src/model/post_dao.py @@ -1,11 +1,15 @@ from cpl.database.abc import DbModelDaoABC +from model.author_dao import AuthorDao from model.post import Post class PostDao(DbModelDaoABC): - def __init__(self): + def __init__(self, authors: AuthorDao): DbModelDaoABC.__init__(self, Post, "posts") + self.attribute(Post.author_id, int, db_name="authorId") + self.reference("author", "id", Post.author_id, "authors", authors) + self.attribute(Post.title, str) self.attribute(Post.content, str) \ No newline at end of file diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 2e5b2998..e3bc41af 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -2,12 +2,15 @@ from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder +from model.author_dao import AuthorDao +from model.author_query import AuthorGraphType, AuthorFilter from model.post import Post class PostFilter(Filter[Post]): def __init__(self): Filter.__init__(self) self.int_field("id") + self.filter_field("author", AuthorFilter) self.string_field("title") self.string_field("content") @@ -20,13 +23,22 @@ class PostSort(Sort[Post]): class PostGraphType(GraphType[Post]): - def __init__(self): + def __init__(self, authors: AuthorDao): GraphType.__init__(self) self.int_field( "id", resolver=lambda root: root.id, ) + + async def _a(root: Post): + return await authors.get_by_id(root.author_id) + + self.object_field( + "author", + AuthorGraphType, + resolver=_a#lambda root: root.author_id, + ) self.string_field( "title", resolver=lambda root: root.title, diff --git a/example/api/src/scripts/0-posts.sql b/example/api/src/scripts/0-posts.sql index bf2ecc62..26268f17 100644 --- a/example/api/src/scripts/0-posts.sql +++ b/example/api/src/scripts/0-posts.sql @@ -1,7 +1,19 @@ +CREATE TABLE IF NOT EXISTS `authors` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `firstname` VARCHAR(64) NOT NULL, + `lastname` VARCHAR(64) NOT NULL, + deleted BOOLEAN NOT NULL DEFAULT FALSE, + editorId INT NULL, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(`id`) + ); + CREATE TABLE IF NOT EXISTS `posts` ( `id` INT(30) NOT NULL AUTO_INCREMENT, - `title` VARCHAR(64) NOT NULL, - `content` VARCHAR(512) NOT NULL, + `authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE, + `title` TEXT NOT NULL, + `content` TEXT NOT NULL, deleted BOOLEAN NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/example/api/src/test_data_seeder.py b/example/api/src/test_data_seeder.py index f50eea6f..38bcc1f1 100644 --- a/example/api/src/test_data_seeder.py +++ b/example/api/src/test_data_seeder.py @@ -2,6 +2,8 @@ from faker import Faker from cpl.database.abc import DataSeederABC from cpl.query import Enumerable +from model.author import Author +from model.author_dao import AuthorDao from model.post import Post from model.post_dao import PostDao @@ -11,19 +13,34 @@ fake = Faker() class TestDataSeeder(DataSeederABC): - def __init__(self, posts: PostDao): + def __init__(self, authors: AuthorDao, posts: PostDao): DataSeederABC.__init__(self) + self._authors = authors self._posts = posts async def seed(self): + if await self._authors.count() == 0: + await self._seed_authors() + if await self._posts.count() == 0: await self._seed_posts() + async def _seed_authors(self): + authors = Enumerable.range(0, 35).select( + lambda x: Author( + 0, + fake.first_name(), + fake.last_name(), + ) + ).to_list() + await self._authors.create_many(authors, skip_editor=True) + async def _seed_posts(self): posts = Enumerable.range(0, 100).select( lambda x: Post( id=0, + author_id=fake.random_int(min=1, max=35), title=fake.sentence(nb_words=6), content=fake.paragraph(nb_sentences=6), ) diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index a5422761..474bf6ce 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -1,6 +1,8 @@ from typing import Optional, Any - import sqlparse +import asyncio + +from mysql.connector import errors, PoolError from mysql.connector.aio import MySQLConnectionPool from cpl.core.environment import Environment @@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider class MySQLPool: - def __init__(self, database_settings: DatabaseSettings): self._dbconfig = { "host": database_settings.host, @@ -25,59 +26,87 @@ class MySQLPool: "ssl_disabled": database_settings.ssl_disabled, } self._pool: Optional[MySQLConnectionPool] = None + self._pool_lock = asyncio.Lock() - async def _get_pool(self): + async def _get_pool(self) -> MySQLConnectionPool: if self._pool is None: - try: - self._pool = MySQLConnectionPool( - pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig - ) - await self._pool.initialize_pool() + async with self._pool_lock: + if self._pool is None: + try: + self._pool = MySQLConnectionPool( + pool_name="cplpool", + pool_size=Environment.get("DB_POOL_SIZE", int, 20), + **self._dbconfig, + ) + await self._pool.initialize_pool() - con = await self._pool.get_connection() - async with await con.cursor() as cursor: - await cursor.execute("SELECT 1") - await cursor.fetchall() - - await con.close() - except Exception as e: - logger = get_provider().get_service(DBLogger) - logger.fatal(f"Error connecting to the database", e) + # Testverbindung (Ping) + con = await self._pool.get_connection() + try: + async with await con.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchall() + finally: + await con.close() + except Exception as e: + logger = get_provider().get_service(DBLogger) + logger.fatal("Error connecting to the database", e) + raise return self._pool + async def _get_connection(self, retries: int = 3, delay: float = 0.5): + """Stabiler Connection-Getter mit Retry und Ping""" + pool = await self._get_pool() + + for attempt in range(retries): + try: + con = await pool.get_connection() + + # Verbindungs-Check (Ping) + try: + async with await con.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchall() + except errors.OperationalError: + await con.close() + raise + + return con + + except PoolError: + if attempt == retries - 1: + raise + await asyncio.sleep(delay) + @staticmethod async def _exec_sql(cursor: Any, query: str, args=None, multi=True): result = [] if multi: queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] for q in queries: - if q.strip() == "": - continue - await cursor.execute(q, args) - if cursor.description is not None: - result = await cursor.fetchall() + if q: + await cursor.execute(q, args) + if cursor.description is not None: + result = await cursor.fetchall() else: await cursor.execute(query, args) if cursor.description is not None: result = await cursor.fetchall() - return result - async def execute(self, query: str, args=None, multi=True) -> list[list]: - pool = await self._get_pool() - con = await pool.get_connection() + async def execute(self, query: str, args=None, multi=True) -> list[str]: + con = await self._get_connection() try: async with await con.cursor() as cursor: - result = await self._exec_sql(cursor, query, args, multi) + res = await self._exec_sql(cursor, query, args, multi) await con.commit() - return result + return list(res) finally: await con.close() async def select(self, query: str, args=None, multi=True) -> list[str]: - pool = await self._get_pool() - con = await pool.get_connection() + con = await self._get_connection() try: async with await con.cursor() as cursor: res = await self._exec_sql(cursor, query, args, multi) @@ -86,8 +115,7 @@ class MySQLPool: await con.close() async def select_map(self, query: str, args=None, multi=True) -> list[dict]: - pool = await self._get_pool() - con = await pool.get_connection() + con = await self._get_connection() try: async with await con.cursor(dictionary=True) as cursor: res = await self._exec_sql(cursor, query, args, multi) diff --git a/src/cpl-database/cpl/database/postgres/postgres_pool.py b/src/cpl-database/cpl/database/postgres/postgres_pool.py index 891fb7f1..434c2655 100644 --- a/src/cpl-database/cpl/database/postgres/postgres_pool.py +++ b/src/cpl-database/cpl/database/postgres/postgres_pool.py @@ -27,7 +27,7 @@ class PostgresPool: self._pool: Optional[AsyncConnectionPool] = None async def _get_pool(self): - if self._pool is None: + if self._pool is None or self._pool.closed: pool = AsyncConnectionPool( conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) ) diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 6cac07a8..1d37a626 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -3,6 +3,7 @@ from typing import Type, Dict, List import strawberry from cpl.core.typing import T +from cpl.dependency import get_provider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -14,7 +15,12 @@ class CollectionGraphTypeFactory: if node_type in cls._cache: return cls._cache[node_type] - gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type + node_t = get_provider().get_service(node_type) + if not node_t: + raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider") + + + gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type gql_type = strawberry.type( type( diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index d1d502e2..6463ace9 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -1,3 +1,5 @@ +from typing import Type + from cpl.core.typing import T from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter @@ -10,6 +12,9 @@ class Filter(Input[T]): def __init__(self): Input.__init__(self) + def filter_field(self, name: str, filter_type: Type["Filter"]): + self.field(name, filter_type()) + def string_field(self, name: str): self.field(name, StringFilter()) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 82ff31de..a4dfebdf 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -5,6 +5,7 @@ import strawberry from cpl.core.typing import T from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field +from cpl.graphql.utils.type_collector import TypeCollector _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} @@ -18,12 +19,10 @@ class Input(StrawberryProtocol, Generic[T]): def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): self._fields[name] = Field(name, typ, optional=optional) - _registry: dict[type, Type] = {} - def to_strawberry(self) -> Type: cls = self.__class__ - if cls in self._registry: - return self._registry[cls] + if TypeCollector.has(cls): + return TypeCollector.get(cls) annotations = {} namespace = {} @@ -50,5 +49,5 @@ class Input(StrawberryProtocol, Generic[T]): namespace["__annotations__"] = annotations gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) - Input._registry[cls] = gql_type + TypeCollector.set(cls, gql_type) return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 736de81f..84270056 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -12,6 +12,7 @@ from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.typing import Resolver +from cpl.graphql.utils.type_collector import TypeCollector class Query(StrawberryProtocol): @@ -54,6 +55,9 @@ class Query(StrawberryProtocol): def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: return self.field(name, list[t], resolver) + def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: + return self.field(name, t().to_strawberry(), resolver) + def with_query(self, name: str, subquery_cls: Type["Query"]): sub = self._provider.get_service(subquery_cls) if not sub: @@ -221,6 +225,10 @@ class Query(StrawberryProtocol): ) from e def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + annotations: dict[str, Any] = {} namespace: dict[str, Any] = {} @@ -229,4 +237,6 @@ class Query(StrawberryProtocol): namespace[name] = self._field_to_strawberry(f) namespace["__annotations__"] = annotations - return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/utils/type_collector.py b/src/cpl-graphql/cpl/graphql/utils/type_collector.py new file mode 100644 index 00000000..c51718bf --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/utils/type_collector.py @@ -0,0 +1,17 @@ +from typing import Type + + +class TypeCollector: + _registry: dict[type, Type] = {} + + @classmethod + def has(cls, base: type) -> bool: + return base in cls._registry + + @classmethod + def get(cls, base: type) -> Type: + return cls._registry[base] + + @classmethod + def set(cls, base: type, gql_type: Type): + cls._registry[base] = gql_type \ No newline at end of file From 6f46b9499881b89ede7728c2251d947938f9de10 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 11:45:51 +0200 Subject: [PATCH 09/20] [WIP] with authentication #181 --- example/api/src/main.py | 2 +- example/api/src/model/author_query.py | 6 +- example/api/src/model/post_query.py | 8 +- .../cpl/api/middleware/authentication.py | 15 ++++ src/cpl-api/cpl/api/middleware/request.py | 40 ++++++++- .../cpl/graphql/_endpoints/graphql.py | 2 +- src/cpl-graphql/cpl/graphql/error.py | 14 +++ src/cpl-graphql/cpl/graphql/graphql_module.py | 2 +- src/cpl-graphql/cpl/graphql/query_context.py | 90 +++++++++++++++++++ src/cpl-graphql/cpl/graphql/schema/field.py | 58 +++++++++++- src/cpl-graphql/cpl/graphql/schema/query.py | 71 +++++++++++++-- .../cpl/graphql/service/graphql.py | 51 +++++++++++ src/cpl-graphql/cpl/graphql/service/schema.py | 15 ++++ .../cpl/graphql/service/service.py | 31 ------- src/cpl-graphql/cpl/graphql/typing.py | 14 ++- 15 files changed, 362 insertions(+), 57 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/error.py create mode 100644 src/cpl-graphql/cpl/graphql/query_context.py create mode 100644 src/cpl-graphql/cpl/graphql/service/graphql.py delete mode 100644 src/cpl-graphql/cpl/graphql/service/service.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 7273381a..c149fe47 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -85,7 +85,7 @@ def main(): schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) - schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) + schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True) app.with_playground() app.with_graphiql() diff --git a/example/api/src/model/author_query.py b/example/api/src/model/author_query.py index f7f1d1df..ae365a7c 100644 --- a/example/api/src/model/author_query.py +++ b/example/api/src/model/author_query.py @@ -26,12 +26,12 @@ class AuthorGraphType(GraphType[Author]): self.int_field( "id", resolver=lambda root: root.id, - ) + ).with_public(True) self.string_field( "firstName", resolver=lambda root: root.first_name, - ) + ).with_public(True) self.string_field( "lastName", resolver=lambda root: root.last_name, - ) + ).with_public(True) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index e3bc41af..48845617 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -29,7 +29,7 @@ class PostGraphType(GraphType[Post]): self.int_field( "id", resolver=lambda root: root.id, - ) + ).with_public(True) async def _a(root: Post): return await authors.get_by_id(root.author_id) @@ -38,12 +38,12 @@ class PostGraphType(GraphType[Post]): "author", AuthorGraphType, resolver=_a#lambda root: root.author_id, - ) + ).with_public(True) self.string_field( "title", resolver=lambda root: root.title, - ) + ).with_public(True) self.string_field( "content", resolver=lambda root: root.content, - ) + ).with_public(True) diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index c0dc95f1..9b45c076 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -25,6 +25,21 @@ class AuthenticationMiddleware(ASGIMiddleware): request = get_request() url = request.url.path + if url not in Router.get_auth_required_routes(): + self._logger.trace(f"No authentication required for {url}") + return await self._app(scope, receive, send) + + user = getattr(request.state, "user", None) + if not user or user.deleted: + self._logger.debug(f"Unauthorized access to {url}, user missing or deleted") + return await Unauthorized("Unauthorized").asgi_response(scope, receive, send) + + return await self._call_next(scope, receive, send) + + async def _old_call__(self, scope: Scope, receive: Receive, send: Send): + request = get_request() + url = request.url.path + if url not in Router.get_auth_required_routes(): self._logger.trace(f"No authentication required for {url}") return await self._app(scope, receive, send) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 0cedc88b..2dc24bc5 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -9,6 +9,10 @@ from starlette.types import Scope, Receive, Send from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.typing import TRequest +from cpl.auth.keycloak.keycloak_client import KeycloakClient +from cpl.auth.schema import AuthUser +from cpl.auth.schema._administration.auth_user_dao import AuthUserDao +from cpl.core.ctx import set_user from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider @@ -17,12 +21,15 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): - def __init__(self, app, provider: ServiceProvider, logger: APILogger): + def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): ASGIMiddleware.__init__(self, app) self._provider = provider self._logger = logger + self._keycloak = keycloak + self._user_dao = user_dao + self._ctx_token = None async def __call__(self, scope: Scope, receive: Receive, send: Send): @@ -30,6 +37,7 @@ class RequestMiddleware(ASGIMiddleware): await self.set_request_data(request) try: + await self._try_set_user(request) with self._provider.create_scope(): inject(await self._app(scope, receive, send)) finally: @@ -53,6 +61,36 @@ class RequestMiddleware(ASGIMiddleware): self._logger.trace(f"Clearing current request: {request.state.request_id}") _request_context.reset(self._ctx_token) + async def _try_set_user(self, request: Request): + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return + + token = auth_header.split("Bearer ")[1] + try: + token_info = self._keycloak.introspect(token) + if not token_info.get("active", False): + return + + keycloak_id = self._keycloak.get_user_id(token) + if not keycloak_id: + return + + user = await self._user_dao.find_by_keycloak_id(keycloak_id) + if not user: + user = AuthUser(0, keycloak_id) + uid = await self._user_dao.create(user) + user = await self._user_dao.get_by_id(uid) + + if user.deleted: + return + + request.state.user = user + set_user(user) + self._logger.trace(f"User {user.id} bound to request {request.state.request_id}") + + except Exception as e: + self._logger.debug(f"Silent user binding failed: {e}") def get_request() -> Optional[TRequest]: return _request_context.get() diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py index 0808d704..01cb133b 100644 --- a/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py @@ -1,7 +1,7 @@ from starlette.requests import Request from starlette.responses import Response, JSONResponse -from cpl.graphql.service.service import GraphQLService +from cpl.graphql.service.graphql import GraphQLService async def graphql_endpoint(request: Request, service: GraphQLService) -> Response: diff --git a/src/cpl-graphql/cpl/graphql/error.py b/src/cpl-graphql/cpl/graphql/error.py new file mode 100644 index 00000000..e96e41c1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/error.py @@ -0,0 +1,14 @@ +from graphql import GraphQLError + +from cpl.api import APIError + + +def graphql_error(api_error: APIError) -> GraphQLError: + """Convert an APIError (from cpl-api) into a GraphQL-friendly error.""" + return GraphQLError( + message=api_error.error_message, + extensions={ + "code": api_error.status_code, + }, + original_error=api_error, + ) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 2d5d6b93..d9d66aee 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -8,7 +8,7 @@ from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.service.schema import Schema -from cpl.graphql.service.service import GraphQLService +from cpl.graphql.service.graphql import GraphQLService class GraphQLModule(Module): diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py new file mode 100644 index 00000000..79d0b965 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -0,0 +1,90 @@ +from enum import Enum +from typing import Optional, Any + +from graphql import GraphQLResolveInfo + +from cpl.auth.schema import AuthUser, Permission +from cpl.core.utils import get_value + + +class QueryContext: + + def __init__( + self, + data: Any, + user: Optional[AuthUser], + user_permissions: Optional[list[Enum | Permission]], + is_mutation: bool = False, + *args, + **kwargs + ): + + self._data = data + self._user = user + if user_permissions is None: + user_permissions = [] + self._user_permissions: list[str] = [x.name for x in user_permissions] + + self._resolve_info = None + for arg in args: + if isinstance(arg, GraphQLResolveInfo): + self._resolve_info = arg + continue + + self._filter = kwargs.get("filters", {}) + self._sort = kwargs.get("sort", {}) + self._skip = get_value(kwargs, "skip", int) + self._take = get_value(kwargs, "take", int) + + self._input = kwargs.get("input", None) + self._args = args + self._kwargs = kwargs + + self._is_mutation = is_mutation + + @property + def data(self): + return self._data + + @property + def user(self) -> AuthUser: + return self._user + + @property + def resolve_info(self) -> Optional[GraphQLResolveInfo]: + return self._resolve_info + + @property + def filter(self) -> dict: + return self._filter + + @property + def sort(self) -> dict: + return self._sort + + @property + def skip(self) -> Optional[int]: + return self._skip + + @property + def take(self) -> Optional[int]: + return self._take + + @property + def input(self) -> Optional[Any]: + return self._input + + @property + def args(self) -> tuple: + return self._args + + @property + def kwargs(self) -> dict: + return self._kwargs + + @property + def is_mutation(self) -> bool: + return self._is_mutation + + def has_permission(self, permission: Enum | str) -> bool: + return permission.value if isinstance(permission, Enum) else permission in self._user_permissions diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 2231e11c..d9417bdb 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -1,7 +1,8 @@ +from enum import Enum from typing import Self from cpl.graphql.schema.argument import Argument -from cpl.graphql.typing import TQuery, Resolver +from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers class Field: @@ -9,7 +10,7 @@ class Field: def __init__( self, name: str, - gql_type: type = None, + t: type = None, resolver: Resolver = None, optional=None, default=None, @@ -17,7 +18,7 @@ class Field: parent_type=None, ): self._name = name - self._gql_type = gql_type + self._type = t self._resolver = resolver self._optional = optional or True self._default = default @@ -26,6 +27,9 @@ class Field: self._parent_type = parent_type self._args: dict[str, Argument] = {} + self._require_any_permission = None + self._require_any = None + self._public = False @property def name(self) -> str: @@ -33,7 +37,7 @@ class Field: @property def type(self) -> type: - return self._gql_type + return self._type @property def resolver(self) -> callable: @@ -63,6 +67,34 @@ class Field: def arguments(self) -> dict[str, Argument]: return self._args + @property + def require_any_permission(self) -> TRequireAnyPermissions | None: + return self._require_any_permission + + @property + def require_any(self) -> TRequireAnyResolvers | None: + return self._require_any + + @property + def public(self) -> bool: + return self._public + + def with_type(self, t: type) -> Self: + self._type = t + return self + + def with_resolver(self, resolver: Resolver) -> Self: + self._resolver = resolver + return self + + def with_optional(self, optional: bool) -> Self: + self._optional = optional + return self + + def with_default(self, default) -> Self: + self._default = default + return self + def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") @@ -76,3 +108,21 @@ class Field: self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) return self + + def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self: + assert permissions is not None, "require_any_permission cannot be None" + assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + self._require_any_permission = permissions + return self + + def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self: + assert permissions is not None, "permissions cannot be None" + assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + assert resolvers is not None, "resolvers cannot be None" + assert all(callable(r) for r in resolvers), "All resolvers must be callable" + self._require_any = (permissions, resolvers) + return self + + def with_public(self, public: bool = False) -> Self: + self._public = public + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 84270056..2f0e23f0 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,13 +1,19 @@ +import asyncio +import functools import inspect from typing import Callable, Type, Any, Optional import strawberry from strawberry.exceptions import StrawberryException +from cpl.api import Unauthorized, Forbidden +from cpl.api.middleware.request import get_request +from cpl.core.ctx import get_user from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.error import graphql_error from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder @@ -202,23 +208,70 @@ class Query(StrawberryProtocol): _resolver.__signature__ = sig return _resolver + def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: + # Signatur vom Original übernehmen + sig = getattr(resolver, "__signature__", None) + + @functools.wraps(resolver) + async def _auth_resolver(*args, **kwargs): + request = get_request() + user = get_user() + + # Public + if f.public: + return await self._maybe_await(resolver(*args, **kwargs)) + + # Auth required + if user is None: + raise graphql_error(Unauthorized("Authentication required")) + + # Permissions + if f.require_any_permission: + if not any(user.has_permission(p) for p in f.require_any_permission): + raise Forbidden("Permission denied") + + # Custom resolvers + if f.require_any: + perms, resolvers = f.require_any + if not any(user.has_permission(p) for p in perms): + for r in resolvers: + ok = await self._maybe_await(r(user, *args, **kwargs)) + if ok: + break + else: + raise Forbidden("Permission denied") + + return await self._maybe_await(resolver(*args, **kwargs)) + + # Signatur beibehalten + if sig: + _auth_resolver.__signature__ = sig + + return _auth_resolver + + @staticmethod + def _maybe_await(value): + if asyncio.iscoroutine(value): + return value + return asyncio.sleep(0, result=value) # sofort resolved Future + + def _field_to_strawberry(self, f: Field) -> Any: + resolver = None try: - if f.resolver: + if f.arguments: + resolver = self._build_resolver(f) + elif not f.resolver: + resolver = lambda *_, **__: None + else: ann = getattr(f.resolver, "__annotations__", {}) if "return" not in ann or ann["return"] is None: ann = dict(ann) ann["return"] = f.type f.resolver.__annotations__ = ann + resolver = f.resolver - if f.arguments: - resolver = self._build_resolver(f) - return strawberry.field(resolver=resolver) - - if not f.resolver: - return strawberry.field(resolver=lambda *_, **__: None) - - return strawberry.field(resolver=f.resolver) + return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) except StrawberryException as e: raise Exception( f"Error converting field '{f.name}' to strawberry field: {e}" diff --git a/src/cpl-graphql/cpl/graphql/service/graphql.py b/src/cpl-graphql/cpl/graphql/service/graphql.py new file mode 100644 index 00000000..c816b2e1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/graphql.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional + +from graphql import GraphQLError + +from cpl.api import APILogger, APIError +from cpl.api.typing import TRequest +from cpl.graphql.service.schema import Schema + + +class GraphQLService: + def __init__(self, logger: APILogger, schema: Schema): + self._logger = logger + + if schema.schema is None: + raise ValueError("Schema has not been built. Call schema.build() before using the service.") + self._schema = schema.schema + + async def execute( + self, + query: str, + variables: Optional[Dict[str, Any]], + request: TRequest, + ) -> Dict[str, Any]: + result = await self._schema.execute( + query, + variable_values=variables, + context_value={"request": request}, + ) + + response_data: Dict[str, Any] = {} + if result.errors: + errors = [] + for error in result.errors: + if isinstance(error, GraphQLError): + self._logger.error(f"GraphQL APIError: {error}") + errors.append({"message": error.message, "extensions": error.extensions}) + continue + + if isinstance(error, APIError): + self._logger.error(f"GraphQL APIError: {error}") + errors.append({"message": error.error_message, "extensions": {"code": error.status_code}}) + continue + + self._logger.error(f"GraphQL unexpected error: {error}") + errors.append({"message": str(error), "extensions": {"code": 500}}) + + response_data["errors"] = errors + if result.data: + response_data["data"] = result.data + + return response_data diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 23627ee4..f0b01b05 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -1,7 +1,11 @@ +import logging from typing import Type, Self import strawberry +from starlette.requests import Request +from strawberry.types import ExecutionContext +from cpl.api import APIError from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -40,7 +44,18 @@ class Schema: return types + def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext): + request: Request = execution_context.context.get("request") + + if isinstance(error, APIError): + self._logger.error(f"GraphQL APIError: {error}") + return {"message": error.error_message, "extensions": {"code": error.status_code}} + + self._logger.error(f"GraphQL unexpected error: {error}") + return {"message": str(error), "extensions": {"code": 500}} + def build(self) -> strawberry.Schema: + logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) query = self._provider.get_service(RootQuery) if not query: raise ValueError("RootQuery not registered in service provider") diff --git a/src/cpl-graphql/cpl/graphql/service/service.py b/src/cpl-graphql/cpl/graphql/service/service.py deleted file mode 100644 index f039ccbd..00000000 --- a/src/cpl-graphql/cpl/graphql/service/service.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any, Dict, Optional - -from cpl.api.typing import TRequest -from cpl.graphql.service.schema import Schema - - -class GraphQLService: - def __init__(self, schema: Schema): - if schema.schema is None: - raise ValueError("Schema has not been built. Call schema.build() before using the service.") - self._schema = schema.schema - - async def execute( - self, - query: str, - variables: Optional[Dict[str, Any]], - request: TRequest, - ) -> Dict[str, Any]: - result = await self._schema.execute( - query, - variable_values=variables, - context_value={"request": request}, - ) - - response_data: Dict[str, Any] = {} - if result.errors: - response_data["errors"] = [str(e) for e in result.errors] - if result.data: - response_data["data"] = result.data - - return response_data diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py index d5b63494..d36e3119 100644 --- a/src/cpl-graphql/cpl/graphql/typing.py +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -1,5 +1,15 @@ -from typing import Type, Callable +from enum import Enum +from typing import Type, Callable, List, Tuple, Awaitable + +from cpl.auth.permission import Permissions +from cpl.graphql.query_context import QueryContext TQuery = Type["Query"] Resolver = Callable -ScalarType = str | int | float | bool | object \ No newline at end of file +ScalarType = str | int | float | bool | object + +TRequireAnyPermissions = List[Enum | Permissions] | None +TRequireAnyResolvers = List[ + Callable[[QueryContext], bool | Awaitable[bool]], +] +TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers] From 3286a95cbfc6768b6cb4a24c311c3b048238055c Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 14:53:57 +0200 Subject: [PATCH 10/20] require any #181 --- example/api/src/main.py | 8 ++- example/api/src/model/post_query.py | 12 ++-- example/api/src/permissions.py | 8 +++ .../cpl/auth/permission/permission_module.py | 3 +- .../cpl/auth/permission/permission_seeder.py | 1 - .../cpl/auth/permission/role_seeder.py | 60 +++++++++++++++++++ .../schema/_administration/auth_user_dao.py | 12 ++-- .../cpl/auth/schema/_permission/role.py | 2 +- .../schema/_permission/role_permission.py | 2 +- .../cpl/auth/schema/_permission/role_user.py | 2 +- .../scripts/mysql/3-roles-permissions.sql | 38 ++++++------ .../scripts/postgres/3-roles-permissions.sql | 8 +-- .../database/abc/data_access_object_abc.py | 2 +- .../cpl/database/table_manager.py | 2 +- src/cpl-graphql/cpl/graphql/query_context.py | 41 +------------ src/cpl-graphql/cpl/graphql/schema/field.py | 11 +++- src/cpl-graphql/cpl/graphql/schema/query.py | 60 ++++++++----------- .../cpl/graphql/service/graphql.py | 15 ++--- src/cpl-graphql/cpl/graphql/service/schema.py | 13 ---- 19 files changed, 165 insertions(+), 135 deletions(-) create mode 100644 example/api/src/permissions.py create mode 100644 src/cpl-auth/cpl/auth/permission/role_seeder.py diff --git a/example/api/src/main.py b/example/api/src/main.py index c149fe47..06e39aa4 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -18,6 +18,7 @@ from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao from model.post_query import PostFilter, PostSort, PostGraphType +from permissions import PostPermissions from queries.hello import HelloQuery from scoped_service import ScopedService from service import PingService @@ -85,11 +86,16 @@ def main(): schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) - schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort).with_public(True) + ( + schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) + .with_require_any_permission(PostPermissions.read) + ) app.with_playground() app.with_graphiql() + app.with_permissions(PostPermissions) + provider = builder.service_provider user_cache = provider.get_service(Cache[AuthUser]) role_cache = provider.get_service(Cache[Role]) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 48845617..381c94ca 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,3 +1,4 @@ +from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort @@ -34,11 +35,12 @@ class PostGraphType(GraphType[Post]): async def _a(root: Post): return await authors.get_by_id(root.author_id) - self.object_field( - "author", - AuthorGraphType, - resolver=_a#lambda root: root.author_id, - ).with_public(True) + def r_name(ctx: QueryContext): + return ctx.user.username == "admin" + + self.object_field("author", AuthorGraphType, resolver=_a).with_require_any( + [], [r_name] + ) self.string_field( "title", resolver=lambda root: root.title, diff --git a/example/api/src/permissions.py b/example/api/src/permissions.py new file mode 100644 index 00000000..d2e1d450 --- /dev/null +++ b/example/api/src/permissions.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class PostPermissions(Enum): + + read = "post.read" + write = "post.write" + delete = "post.delete" \ No newline at end of file diff --git a/src/cpl-auth/cpl/auth/permission/permission_module.py b/src/cpl-auth/cpl/auth/permission/permission_module.py index 16955c57..eafaeadc 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_module.py +++ b/src/cpl-auth/cpl/auth/permission/permission_module.py @@ -2,6 +2,7 @@ from cpl.auth.auth_module import AuthModule from cpl.auth.permission.permission_seeder import PermissionSeeder from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry +from cpl.auth.permission.role_seeder import RoleSeeder from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.database_module import DatabaseModule from cpl.dependency.module.module import Module @@ -10,7 +11,7 @@ from cpl.dependency.service_collection import ServiceCollection class PermissionsModule(Module): dependencies = [DatabaseModule, AuthModule] - singleton = [(DataSeederABC, PermissionSeeder)] + transient = [(DataSeederABC, PermissionSeeder), (DataSeederABC, RoleSeeder)] @staticmethod def register(collection: ServiceCollection): diff --git a/src/cpl-auth/cpl/auth/permission/permission_seeder.py b/src/cpl-auth/cpl/auth/permission/permission_seeder.py index d9d42cfa..aab41139 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_seeder.py +++ b/src/cpl-auth/cpl/auth/permission/permission_seeder.py @@ -1,4 +1,3 @@ -from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.schema import ( Permission, diff --git a/src/cpl-auth/cpl/auth/permission/role_seeder.py b/src/cpl-auth/cpl/auth/permission/role_seeder.py new file mode 100644 index 00000000..2c7687bd --- /dev/null +++ b/src/cpl-auth/cpl/auth/permission/role_seeder.py @@ -0,0 +1,60 @@ +from cpl.auth.schema import ( + Role, + RolePermission, + PermissionDao, + RoleDao, + RolePermissionDao, + ApiKeyDao, + ApiKeyPermissionDao, + AuthUserDao, + RoleUserDao, + RoleUser, +) +from cpl.database.abc.data_seeder_abc import DataSeederABC +from cpl.database.logger import DBLogger + + +class RoleSeeder(DataSeederABC): + def __init__( + self, + logger: DBLogger, + permission_dao: PermissionDao, + role_dao: RoleDao, + role_permission_dao: RolePermissionDao, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + user_dao: AuthUserDao, + role_user_dao: RoleUserDao, + ): + DataSeederABC.__init__(self) + self._logger = logger + self._permission_dao = permission_dao + self._role_dao = role_dao + self._role_permission_dao = role_permission_dao + self._api_key_dao = api_key_dao + self._api_key_permission_dao = api_key_permission_dao + self._user_dao = user_dao + self._role_user_dao = role_user_dao + + async def seed(self): + self._logger.info("Creating admin role") + roles = await self._role_dao.get_all() + if len(roles) == 0: + rid = await self._role_dao.create(Role(0, "admin", "Default admin role")) + permissions = await self._permission_dao.get_all() + + await self._role_permission_dao.create_many( + [RolePermission(0, rid, permission.id) for permission in permissions] + ) + + role = await self._role_dao.get_by_name("admin") + if len(await role.users) > 0: + return + + users = await self._user_dao.get_all() + if len(users) == 0: + return + + user = users[0] + self._logger.warning(f"Assigning admin role to first user {user.id}") + await self._role_user_dao.create(RoleUser(0, role.id, user.id)) diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py index 4b27549a..bf59a534 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py @@ -1,6 +1,8 @@ from typing import Optional, Union from cpl.auth.permission.permissions import Permissions +from cpl.auth.schema._permission.permission_dao import PermissionDao +from cpl.auth.schema._permission.permission import Permission from cpl.auth.schema._administration.auth_user import AuthUser from cpl.database import TableManager from cpl.database.abc import DbModelDaoABC @@ -10,10 +12,12 @@ from cpl.dependency.context import get_provider class AuthUserDao(DbModelDaoABC[AuthUser]): - def __init__(self): + def __init__(self, permission_dao: PermissionDao): DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) - self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") + self._permissions = permission_dao + + self.attribute(AuthUser.keycloak_id, str) async def get_users(): return [(x.id, x.username, x.email) for x in await self.get_all()] @@ -54,7 +58,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): return result[0]["count"] > 0 - async def get_permissions(self, user_id: int) -> list[Permissions]: + async def get_permissions(self, user_id: int) -> list[Permission]: result = await self._db.select_map( f""" SELECT p.* @@ -66,4 +70,4 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): AND ru.deleted = FALSE; """ ) - return [Permissions(p["name"]) for p in result] + return [self._permissions.to_object(x) for x in result] diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role.py b/src/cpl-auth/cpl/auth/schema/_permission/role.py index 24a5d82d..3c1b0a1f 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role.py @@ -6,7 +6,7 @@ from async_property import async_property from cpl.auth.permission.permissions import Permissions from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class Role(DbModelABC[Self]): diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py index 82bacb4a..c58d8682 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py @@ -5,7 +5,7 @@ from async_property import async_property from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class RolePermission(DbModelABC[Self]): diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py index 5db0f892..72504768 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py @@ -5,7 +5,7 @@ from async_property import async_property from cpl.core.typing import SerialId from cpl.database.abc import DbJoinModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class RoleUser(DbJoinModelABC): diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql index f3082a48..63a58fbf 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql @@ -89,14 +89,14 @@ END; CREATE TABLE IF NOT EXISTS permission_role_permissions ( id INT AUTO_INCREMENT PRIMARY KEY, - RoleId INT NOT NULL, + roleId INT NOT NULL, permissionId INT NOT NULL, deleted BOOL NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId), - CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId), + CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) ); @@ -104,7 +104,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions CREATE TABLE IF NOT EXISTS permission_role_permissions_history ( id INT NOT NULL, - RoleId INT NOT NULL, + roleId INT NOT NULL, permissionId INT NOT NULL, deleted BOOL NOT NULL, editorId INT NULL, @@ -118,8 +118,8 @@ CREATE TRIGGER TR_RolePermissionsUpdate FOR EACH ROW BEGIN INSERT INTO permission_role_permissions_history - (id, RoleId, permissionId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); + (id, roleId, permissionId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; CREATE TRIGGER TR_RolePermissionsDelete @@ -128,30 +128,30 @@ CREATE TRIGGER TR_RolePermissionsDelete FOR EACH ROW BEGIN INSERT INTO permission_role_permissions_history - (id, RoleId, permissionId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); + (id, roleId, permissionId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); END; CREATE TABLE IF NOT EXISTS permission_role_auth_users ( id INT AUTO_INCREMENT PRIMARY KEY, - RoleId INT NOT NULL, - UserId INT NOT NULL, + roleId INT NOT NULL, + userId INT NOT NULL, deleted BOOL NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId), - CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, - CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (UserId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, + CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId), + CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (userId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) ); CREATE TABLE IF NOT EXISTS permission_role_auth_users_history ( id INT NOT NULL, - RoleId INT NOT NULL, - UserId INT NOT NULL, + roleId INT NOT NULL, + userId INT NOT NULL, deleted BOOL NOT NULL, editorId INT NULL, created TIMESTAMP NOT NULL, @@ -164,8 +164,8 @@ CREATE TRIGGER TR_Roleauth_usersUpdate FOR EACH ROW BEGIN INSERT INTO permission_role_auth_users_history - (id, RoleId, UserId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.UserId, OLD.deleted, OLD.editorId, OLD.created, NOW()); + (id, roleId, userId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; CREATE TRIGGER TR_Roleauth_usersDelete @@ -174,6 +174,6 @@ CREATE TRIGGER TR_Roleauth_usersDelete FOR EACH ROW BEGIN INSERT INTO permission_role_auth_users_history - (id, RoleId, UserId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.UserId, 1, OLD.editorId, OLD.created, NOW()); + (id, roleId, userId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW()); END; diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql index 42b9283b..72400191 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql @@ -79,7 +79,7 @@ CREATE TRIGGER versioning_trigger EXECUTE PROCEDURE public.history_trigger_function(); -- Role user -CREATE TABLE permission.role_users +CREATE TABLE permission.role_auth_users ( id SERIAL PRIMARY KEY, RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE, @@ -93,13 +93,13 @@ CREATE TABLE permission.role_users CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId) ); -CREATE TABLE permission.role_users_history +CREATE TABLE permission.role_auth_users_history ( - LIKE permission.role_users + LIKE permission.role_auth_users ); CREATE TRIGGER versioning_trigger BEFORE INSERT OR UPDATE OR DELETE - ON permission.role_users + ON permission.role_auth_users FOR EACH ROW EXECUTE PROCEDURE public.history_trigger_function(); \ No newline at end of file diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py index 44f2a0bf..7f1e235b 100644 --- a/src/cpl-database/cpl/database/abc/data_access_object_abc.py +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -85,7 +85,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): self.__ignored_attributes.add(attr_name) if not db_name: - db_name = attr_name.lower().replace("_", "") + db_name = String.to_camel_case(attr_name) self.__db_names[attr_name] = db_name self.__db_names[db_name] = db_name diff --git a/src/cpl-database/cpl/database/table_manager.py b/src/cpl-database/cpl/database/table_manager.py index 9bd1f6b2..2d5ac533 100644 --- a/src/cpl-database/cpl/database/table_manager.py +++ b/src/cpl-database/cpl/database/table_manager.py @@ -32,7 +32,7 @@ class TableManager: ServerTypes.MYSQL: "permission_role_permissions", }, "role_users": { - ServerTypes.POSTGRES: "permission.role_users", + ServerTypes.POSTGRES: "permission.role_auth_users", ServerTypes.MYSQL: "permission_role_auth_users", }, } diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py index 79d0b965..9b75d694 100644 --- a/src/cpl-graphql/cpl/graphql/query_context.py +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -4,6 +4,7 @@ from typing import Optional, Any from graphql import GraphQLResolveInfo from cpl.auth.schema import AuthUser, Permission +from cpl.core.ctx import get_user from cpl.core.utils import get_value @@ -11,19 +12,13 @@ class QueryContext: def __init__( self, - data: Any, - user: Optional[AuthUser], user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs ): - - self._data = data - self._user = user - if user_permissions is None: - user_permissions = [] - self._user_permissions: list[str] = [x.name for x in user_permissions] + self._user = get_user() + self._user_permissions = user_permissions or [] self._resolve_info = None for arg in args: @@ -31,21 +26,11 @@ class QueryContext: self._resolve_info = arg continue - self._filter = kwargs.get("filters", {}) - self._sort = kwargs.get("sort", {}) - self._skip = get_value(kwargs, "skip", int) - self._take = get_value(kwargs, "take", int) - - self._input = kwargs.get("input", None) self._args = args self._kwargs = kwargs self._is_mutation = is_mutation - @property - def data(self): - return self._data - @property def user(self) -> AuthUser: return self._user @@ -54,26 +39,6 @@ class QueryContext: def resolve_info(self) -> Optional[GraphQLResolveInfo]: return self._resolve_info - @property - def filter(self) -> dict: - return self._filter - - @property - def sort(self) -> dict: - return self._sort - - @property - def skip(self) -> Optional[int]: - return self._skip - - @property - def take(self) -> Optional[int]: - return self._take - - @property - def input(self) -> Optional[Any]: - return self._input - @property def args(self) -> tuple: return self._args diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index d9417bdb..421413a4 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -5,7 +5,7 @@ from cpl.graphql.schema.argument import Argument from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers -class Field: +class Field: def __init__( self, @@ -109,9 +109,12 @@ class Field: self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) return self - def with_require_any_permission(self, permissions: TRequireAnyPermissions) -> Self: + def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self: + if not isinstance(permissions, list): + permissions = list(permissions) + assert permissions is not None, "require_any_permission cannot be None" - assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + assert all(isinstance(x, (str, Enum)) for x in permissions), "All permissions must be of Permission type" self._require_any_permission = permissions return self @@ -124,5 +127,7 @@ class Field: return self def with_public(self, public: bool = False) -> Self: + assert self._require_any is None, "Field cannot be public and have require_any set" + assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set" self._public = public return self diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 2f0e23f0..0b8df16f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,19 +1,19 @@ -import asyncio import functools import inspect +from asyncio import iscoroutinefunction from typing import Callable, Type, Any, Optional import strawberry from strawberry.exceptions import StrawberryException from cpl.api import Unauthorized, Forbidden -from cpl.api.middleware.request import get_request from cpl.core.ctx import get_user from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.error import graphql_error +from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder @@ -141,7 +141,6 @@ class Query(StrawberryProtocol): if v is None: continue - # verschachtelte Inputs rekursiv if hasattr(v, "__dict__"): result[k] = input_to_dict(v) else: @@ -152,9 +151,6 @@ class Query(StrawberryProtocol): filter_dict = input_to_dict(filter) if filter is not None else None sort_dict = None - if filter is not None: - pass - if sort is not None: sort_dict = {} for k, v in sort.__dict__.items(): @@ -202,59 +198,55 @@ class Query(StrawberryProtocol): sig = inspect.Signature(parameters=params, return_annotation=f.type) - def _resolver(*args, **kwargs): - return f.resolver(*args, **kwargs) if f.resolver else None + async def _resolver(*args, **kwargs): + if f.resolver is None: + return None + + if iscoroutinefunction(f.resolver): + return await f.resolver(*args, **kwargs) + return f.resolver(*args, **kwargs) _resolver.__signature__ = sig return _resolver def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: - # Signatur vom Original übernehmen sig = getattr(resolver, "__signature__", None) @functools.wraps(resolver) async def _auth_resolver(*args, **kwargs): - request = get_request() + if f.public: + return await self._run_resolver(resolver, *args, **kwargs) + user = get_user() - # Public - if f.public: - return await self._maybe_await(resolver(*args, **kwargs)) - - # Auth required if user is None: - raise graphql_error(Unauthorized("Authentication required")) + raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) - # Permissions if f.require_any_permission: - if not any(user.has_permission(p) for p in f.require_any_permission): - raise Forbidden("Permission denied") + if not any([await user.has_permission(p) for p in f.require_any_permission]): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) - # Custom resolvers if f.require_any: perms, resolvers = f.require_any - if not any(user.has_permission(p) for p in perms): - for r in resolvers: - ok = await self._maybe_await(r(user, *args, **kwargs)) - if ok: - break - else: - raise Forbidden("Permission denied") + if not any([await user.has_permission(p) for p in perms]): + ctx = QueryContext([x.name for x in await user.permissions]) + resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] - return await self._maybe_await(resolver(*args, **kwargs)) + if not any(resolved): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + return await self._run_resolver(resolver, *args, **kwargs) - # Signatur beibehalten if sig: _auth_resolver.__signature__ = sig return _auth_resolver @staticmethod - def _maybe_await(value): - if asyncio.iscoroutine(value): - return value - return asyncio.sleep(0, result=value) # sofort resolved Future - + async def _run_resolver(r: Callable, *args, **kwargs): + if iscoroutinefunction(r): + return await r(*args, **kwargs) + return r(*args, **kwargs) def _field_to_strawberry(self, f: Field) -> Any: resolver = None diff --git a/src/cpl-graphql/cpl/graphql/service/graphql.py b/src/cpl-graphql/cpl/graphql/service/graphql.py index c816b2e1..cb4ee667 100644 --- a/src/cpl-graphql/cpl/graphql/service/graphql.py +++ b/src/cpl-graphql/cpl/graphql/service/graphql.py @@ -31,17 +31,18 @@ class GraphQLService: if result.errors: errors = [] for error in result.errors: - if isinstance(error, GraphQLError): - self._logger.error(f"GraphQL APIError: {error}") - errors.append({"message": error.message, "extensions": error.extensions}) - continue - if isinstance(error, APIError): - self._logger.error(f"GraphQL APIError: {error}") + self._logger.error(f"GraphQL APIError", error) errors.append({"message": error.error_message, "extensions": {"code": error.status_code}}) continue - self._logger.error(f"GraphQL unexpected error: {error}") + if isinstance(error, GraphQLError): + + self._logger.error(f"GraphQLError", error) + errors.append({"message": error.message, "extensions": error.extensions}) + continue + + self._logger.error(f"GraphQL unexpected error", error) errors.append({"message": str(error), "extensions": {"code": 500}}) response_data["errors"] = errors diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index f0b01b05..c1c43cdc 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -2,10 +2,7 @@ import logging from typing import Type, Self import strawberry -from starlette.requests import Request -from strawberry.types import ExecutionContext -from cpl.api import APIError from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -44,16 +41,6 @@ class Schema: return types - def _graphql_exception_handler(self, error: Exception, execution_context: ExecutionContext): - request: Request = execution_context.context.get("request") - - if isinstance(error, APIError): - self._logger.error(f"GraphQL APIError: {error}") - return {"message": error.error_message, "extensions": {"code": error.status_code}} - - self._logger.error(f"GraphQL unexpected error: {error}") - return {"message": str(error), "extensions": {"code": 500}} - def build(self) -> strawberry.Schema: logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) query = self._provider.get_service(RootQuery) From 39d06dfe48aa14a5c5b723b14b9a8f54ab3de117 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 18:51:28 +0200 Subject: [PATCH 11/20] Added mutations #181 --- example/api/src/main.py | 13 +- example/api/src/model/author_dao.py | 4 +- example/api/src/model/post.py | 8 + example/api/src/model/post_dao.py | 2 +- example/api/src/model/post_query.py | 85 +++++++- example/api/src/queries/hello.py | 2 +- src/cpl-graphql/cpl/graphql/abc/query_abc.py | 178 +++++++++++++++++ src/cpl-graphql/cpl/graphql/graphql_module.py | 5 +- src/cpl-graphql/cpl/graphql/query_context.py | 3 +- .../cpl/graphql/schema/argument.py | 38 ++-- src/cpl-graphql/cpl/graphql/schema/field.py | 14 +- src/cpl-graphql/cpl/graphql/schema/input.py | 44 ++++- .../cpl/graphql/schema/mutation.py | 25 +++ src/cpl-graphql/cpl/graphql/schema/query.py | 184 ++---------------- .../cpl/graphql/schema/root_mutation.py | 6 + src/cpl-graphql/cpl/graphql/service/schema.py | 23 ++- 16 files changed, 424 insertions(+), 210 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/abc/query_abc.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/mutation.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/root_mutation.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 06e39aa4..fdd7dff0 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -5,7 +5,6 @@ from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, A from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder -from cpl.auth.permission.permissions import Permissions from cpl.auth.schema import AuthUser, Role from cpl.core.configuration import Configuration from cpl.core.console import Console @@ -17,7 +16,7 @@ from cpl.graphql.graphql_module import GraphQLModule from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao -from model.post_query import PostFilter, PostSort, PostGraphType +from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation from permissions import PostPermissions from queries.hello import HelloQuery from scoped_service import ScopedService @@ -64,6 +63,7 @@ def main(): .add_transient(PostGraphType) .add_transient(PostFilter) .add_transient(PostSort) + .add_transient(PostMutation) ) app = builder.build() @@ -77,8 +77,8 @@ def main(): path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", - authentication=True, - permissions=[Permissions.administrator], + # authentication=True, + # permissions=[Permissions.administrator], ) app.with_routes_directory("routes") @@ -88,9 +88,12 @@ def main(): schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) ( schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) - .with_require_any_permission(PostPermissions.read) + # .with_require_any_permission(PostPermissions.read) + .with_public() ) + schema.mutation.with_mutation("post", PostMutation).with_public() + app.with_playground() app.with_graphiql() diff --git a/example/api/src/model/author_dao.py b/example/api/src/model/author_dao.py index 98b997a6..d1b1afc0 100644 --- a/example/api/src/model/author_dao.py +++ b/example/api/src/model/author_dao.py @@ -7,5 +7,5 @@ class AuthorDao(DbModelDaoABC): def __init__(self): DbModelDaoABC.__init__(self, Author, "authors") - self.attribute(Author.first_name, str) - self.attribute(Author.last_name, str) \ No newline at end of file + self.attribute(Author.first_name, str, db_name="firstname") + self.attribute(Author.last_name, str, db_name="lastname") \ No newline at end of file diff --git a/example/api/src/model/post.py b/example/api/src/model/post.py index d5801cd0..15b670b8 100644 --- a/example/api/src/model/post.py +++ b/example/api/src/model/post.py @@ -31,6 +31,14 @@ class Post(DbModelABC[Self]): def title(self) -> str: return self._title + @title.setter + def title(self, value: str): + self._title = value + @property def content(self) -> str: return self._content + + @content.setter + def content(self, value: str): + self._content = value diff --git a/example/api/src/model/post_dao.py b/example/api/src/model/post_dao.py index be8e5668..3205f8de 100644 --- a/example/api/src/model/post_dao.py +++ b/example/api/src/model/post_dao.py @@ -3,7 +3,7 @@ from model.author_dao import AuthorDao from model.post import Post -class PostDao(DbModelDaoABC): +class PostDao(DbModelDaoABC[Post]): def __init__(self, authors: AuthorDao): DbModelDaoABC.__init__(self, Post, "posts") diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 381c94ca..6334c51e 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,11 +1,15 @@ from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.input import Input +from cpl.graphql.schema.mutation import Mutation from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter from model.post import Post +from model.post_dao import PostDao + class PostFilter(Filter[Post]): def __init__(self): @@ -38,9 +42,7 @@ class PostGraphType(GraphType[Post]): def r_name(ctx: QueryContext): return ctx.user.username == "admin" - self.object_field("author", AuthorGraphType, resolver=_a).with_require_any( - [], [r_name] - ) + self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name])) self.string_field( "title", resolver=lambda root: root.title, @@ -49,3 +51,80 @@ class PostGraphType(GraphType[Post]): "content", resolver=lambda root: root.content, ).with_public(True) + + +class PostCreateInput(Input[Post]): + title: str + content: str + author_id: int + + def __init__(self): + Input.__init__(self) + self.string_field("title").with_required() + self.string_field("content").with_required() + self.int_field("author_id").with_required() + +class PostUpdateInput(Input[Post]): + title: str + content: str + author_id: int + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("title").with_required(False) + self.string_field("content").with_required(False) + +class PostMutation(Mutation): + + def __init__(self, posts: PostDao, authors: AuthorDao): + Mutation.__init__(self) + + self._posts = posts + self._authors = authors + + self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument( + "input", + PostCreateInput, + ).with_required() + self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument( + "input", + PostUpdateInput, + ).with_required() + self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument( + "id", + int, + ).with_required() + self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument( + "id", + int, + ).with_required() + + async def create_post(self, input: PostCreateInput) -> int: + return await self._posts.create(Post(0, input.author_id, input.title, input.content)) + + async def update_post(self, input: PostUpdateInput) -> bool: + post = await self._posts.get_by_id(input.id) + if post is None: + return False + + post.title = input.title if input.title is not None else post.title + post.content = input.content if input.content is not None else post.content + + await self._posts.update(post) + return True + + async def delete_post(self, id: int) -> bool: + post = await self._posts.get_by_id(id) + if post is None: + return False + await self._posts.delete(post) + return True + + async def restore_post(self, id: int) -> bool: + post = await self._posts.get_by_id(id) + if post is None: + return False + await self._posts.restore(post) + return True + diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index addd9173..864e39ab 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -44,7 +44,7 @@ class HelloQuery(Query): self.string_field( "message", resolver=lambda name: f"Hello {name} {get_request().state.request_id}", - ).with_argument(str, "name", "Name to greet", "world") + ).with_argument("name", str, "Name to greet", "world") self.collection_field( UserGraphType, diff --git a/src/cpl-graphql/cpl/graphql/abc/query_abc.py b/src/cpl-graphql/cpl/graphql/abc/query_abc.py new file mode 100644 index 00000000..5023ebea --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/query_abc.py @@ -0,0 +1,178 @@ +import functools +import inspect +from abc import ABC +from asyncio import iscoroutinefunction +from typing import Callable, Type, Any, Optional + +import strawberry +from strawberry.exceptions import StrawberryException + +from cpl.api import Unauthorized, Forbidden +from cpl.core.ctx.user_context import get_user +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.error import graphql_error +from cpl.graphql.query_context import QueryContext +from cpl.graphql.schema.field import Field +from cpl.graphql.typing import Resolver +from cpl.graphql.utils.type_collector import TypeCollector + + +class QueryABC(StrawberryProtocol, ABC): + + def __init__(self): + ABC.__init__(self) + self._fields: dict[str, Field] = {} + + @property + def fields(self) -> dict[str, Field]: + return self._fields + + @property + def fields_count(self) -> int: + return len(self._fields) + + def get_fields(self) -> dict[str, Field]: + return self._fields + + def field( + self, + name: str, + t: type, + resolver: Resolver = None, + ) -> Field: + from cpl.graphql.schema.field import Field + + self._fields[name] = Field(name, t, resolver) + return self._fields[name] + + def string_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, str, resolver) + + def int_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, int, resolver) + + def float_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, float, resolver) + + def bool_field(self, name: str, resolver: Resolver = None) -> Field: + return self.field(name, bool, resolver) + + def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: + return self.field(name, list[t], resolver) + + def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: + return self.field(name, t().to_strawberry(), resolver) + + @staticmethod + def _build_resolver(f: "Field"): + params: list[inspect.Parameter] = [] + for arg in f.arguments.values(): + _type = arg.type + if isinstance(_type, type) and issubclass(_type, StrawberryProtocol): + _type = _type().to_strawberry() + + ann = Optional[_type] if arg.optional else _type + + if arg.default is None: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + ) + else: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + default=arg.default, + ) + + params.append(param) + + sig = inspect.Signature(parameters=params, return_annotation=f.type) + + async def _resolver(*args, **kwargs): + if f.resolver is None: + return None + + if iscoroutinefunction(f.resolver): + return await f.resolver(*args, **kwargs) + return f.resolver(*args, **kwargs) + + _resolver.__signature__ = sig + return _resolver + + def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: + sig = getattr(resolver, "__signature__", None) + + @functools.wraps(resolver) + async def _auth_resolver(*args, **kwargs): + if f.public: + return await self._run_resolver(resolver, *args, **kwargs) + + user = get_user() + + if user is None: + raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) + + if f.require_any_permission: + if not any([await user.has_permission(p) for p in f.require_any_permission]): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + if f.require_any: + perms, resolvers = f.require_any + if not any([await user.has_permission(p) for p in perms]): + ctx = QueryContext([x.name for x in await user.permissions]) + resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] + + if not any(resolved): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + return await self._run_resolver(resolver, *args, **kwargs) + + if sig: + _auth_resolver.__signature__ = sig + + return _auth_resolver + + @staticmethod + async def _run_resolver(r: Callable, *args, **kwargs): + if iscoroutinefunction(r): + return await r(*args, **kwargs) + return r(*args, **kwargs) + + def _field_to_strawberry(self, f: Field) -> Any: + resolver = None + try: + if f.arguments: + resolver = self._build_resolver(f) + elif not f.resolver: + resolver = lambda *_, **__: None + else: + ann = getattr(f.resolver, "__annotations__", {}) + if "return" not in ann or ann["return"] is None: + ann = dict(ann) + ann["return"] = f.type + f.resolver.__annotations__ = ann + resolver = f.resolver + + return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) + except StrawberryException as e: + raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e + + def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + annotations[name] = f.type + namespace[name] = self._field_to_strawberry(f) + + namespace["__annotations__"] = annotations + gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index d9d66aee..b749d16e 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -6,14 +6,15 @@ from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.string_filter import StringFilter +from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery -from cpl.graphql.service.schema import Schema from cpl.graphql.service.graphql import GraphQLService +from cpl.graphql.service.schema import Schema class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [Schema, RootQuery] + singleton = [Schema, RootQuery, RootMutation] scoped = [GraphQLService] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py index 9b75d694..0c8f5781 100644 --- a/src/cpl-graphql/cpl/graphql/query_context.py +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -1,11 +1,10 @@ from enum import Enum -from typing import Optional, Any +from typing import Optional from graphql import GraphQLResolveInfo from cpl.auth.schema import AuthUser, Permission from cpl.core.ctx import get_user -from cpl.core.utils import get_value class QueryContext: diff --git a/src/cpl-graphql/cpl/graphql/schema/argument.py b/src/cpl-graphql/cpl/graphql/schema/argument.py index cbf8b32f..3332ddd0 100644 --- a/src/cpl-graphql/cpl/graphql/schema/argument.py +++ b/src/cpl-graphql/cpl/graphql/schema/argument.py @@ -1,38 +1,54 @@ -from typing import Any +from typing import Any, Self class Argument: def __init__( self, - t: type, name: str, + t: type, description: str = None, - default_value: Any = None, + default: Any = None, optional: bool = None, ): - self._type = t self._name = name + self._type = t self._description = description - self._default_value = default_value + self._default = default self._optional = optional - @property - def type(self) -> type: - return self._type - @property def name(self) -> str: return self._name + @property + def type(self) -> type: + return self._type + @property def description(self) -> str | None: return self._description @property - def default_value(self) -> Any | None: - return self._default_value + def default(self) -> Any | None: + return self._default @property def optional(self) -> bool | None: return self._optional + + def with_description(self, description: str) -> Self: + self._description = description + return self + + def with_default(self, default: Any) -> Self: + self._default = default + return self + + def with_optional(self, optional: bool) -> Self: + self._optional = optional + return self + + def with_required(self, required: bool = True) -> Self: + self._optional = not required + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 421413a4..8eceba25 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -91,22 +91,26 @@ class Field: self._optional = optional return self + def with_required(self, required: bool = True) -> Self: + self._optional = not required + return self + def with_default(self, default) -> Self: self._default = default return self - def with_argument(self, arg_type: type, name: str, description: str = None, default_value=None, optional=True) -> Self: + def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") - self._args[name] = Argument(arg_type, name, description, default_value, optional) - return self + self._args[name] = Argument(name, arg_type, description, default_value, optional) + return self._args[name] def with_arguments(self, args: list[Argument]) -> Self: for arg in args: if not isinstance(arg, Argument): raise ValueError(f"Expected Argument instance, got {type(arg)}") - self.with_argument(arg.type, arg.name, arg.description, arg.default_value, arg.optional) + self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional) return self def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self: @@ -126,7 +130,7 @@ class Field: self._require_any = (permissions, resolvers) return self - def with_public(self, public: bool = False) -> Self: + def with_public(self, public: bool = True) -> Self: assert self._require_any is None, "Field cannot be public and have require_any set" assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set" self._public = public diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index a4dfebdf..6e639db3 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,4 +1,4 @@ -from typing import Generic, Dict, Type, Optional, Self, Union +from typing import Generic, Dict, Type, Optional, Union, Any import strawberry @@ -12,12 +12,52 @@ _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} class Input(StrawberryProtocol, Generic[T]): def __init__(self): self._fields: Dict[str, Field] = {} + self._values: Dict[str, Any] = {} + + @property + def fields(self) -> Dict[str, Field]: + return self._fields + + def __getattr__(self, item): + if item in self._values: + return self._values[item] + raise AttributeError(f"{self.__class__.__name__} has no attribute {item}") + + def __setattr__(self, key, value): + if key in {"_fields", "_values"}: + super().__setattr__(key, value) + elif key in self._fields: + self._values[key] = value + else: + super().__setattr__(key, value) + + def get(self, key: str, default=None): + return self._values.get(key, default) def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): + def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field: self._fields[name] = Field(name, typ, optional=optional) + return self._fields[name] + + def string_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, str) + + def int_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, int, optional) + + def float_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, float, optional) + + def bool_field(self, name: str, optional: bool = True) -> Field: + return self.field(name, bool, optional) + + def list_field(self, name: str, t: type, optional: bool = True) -> Field: + return self.field(name, list[t], optional) + + def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field: + return self.field(name, t().to_strawberry(), optional) def to_strawberry(self) -> Type: cls = self.__class__ diff --git a/src/cpl-graphql/cpl/graphql/schema/mutation.py b/src/cpl-graphql/cpl/graphql/schema/mutation.py new file mode 100644 index 00000000..691cee10 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/mutation.py @@ -0,0 +1,25 @@ +from typing import Type + +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.query_abc import QueryABC +from cpl.graphql.schema.field import Field + + +class Mutation(QueryABC): + + @inject + def __init__(self, provider: ServiceProvider): + QueryABC.__init__(self) + self._provider = provider + + from cpl.graphql.service.schema import Schema + + self._schema = provider.get_service(Schema) + + def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field: + sub = self._provider.get_service(cls) + if not sub: + raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider") + + return self.field(name, sub.to_strawberry(), lambda: sub) diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 0b8df16f..cbd05781 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -1,76 +1,32 @@ -import functools -import inspect -from asyncio import iscoroutinefunction -from typing import Callable, Type, Any, Optional +from typing import Callable, Type -import strawberry -from strawberry.exceptions import StrawberryException - -from cpl.api import Unauthorized, Forbidden -from cpl.core.ctx import get_user from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.query_abc import QueryABC from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol -from cpl.graphql.error import graphql_error -from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder -from cpl.graphql.typing import Resolver -from cpl.graphql.utils.type_collector import TypeCollector -class Query(StrawberryProtocol): +class Query(QueryABC): @inject def __init__(self, provider: ServiceProvider): + QueryABC.__init__(self) self._provider = provider from cpl.graphql.service.schema import Schema self._schema = provider.get_service(Schema) - self._fields: dict[str, Field] = {} - def get_fields(self) -> dict[str, Field]: - return self._fields - - def field( - self, - name: str, - t: type, - resolver: Resolver = None, - ) -> Field: - from cpl.graphql.schema.field import Field - - self._fields[name] = Field(name, t, resolver) - return self._fields[name] - - def string_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, str, resolver) - - def int_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, int, resolver) - - def float_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, float, resolver) - - def bool_field(self, name: str, resolver: Resolver = None) -> Field: - return self.field(name, bool, resolver) - - def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: - return self.field(name, list[t], resolver) - - def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: - return self.field(name, t().to_strawberry(), resolver) - - def with_query(self, name: str, subquery_cls: Type["Query"]): + def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field: sub = self._provider.get_service(subquery_cls) if not sub: raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider") - self.field(name, sub.to_strawberry(), lambda: sub) - return self + return self.field(name, sub.to_strawberry(), lambda: sub) def collection_field( self, @@ -105,10 +61,10 @@ class Query(StrawberryProtocol): raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection) - f.with_argument(filter.to_strawberry(), "filter") - f.with_argument(sort.to_strawberry(), "sort") - f.with_argument(int, "skip", default_value=0) - f.with_argument(int, "take", default_value=10) + f.with_argument("filter", filter.to_strawberry()) + f.with_argument("sort", sort.to_strawberry()) + f.with_argument("skip", int, default_value=0) + f.with_argument("take", int, default_value=10) return f def dao_collection_field( @@ -168,120 +124,8 @@ class Query(StrawberryProtocol): return Collection(nodes=data, total_count=total_count, count=len(data)) f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver) - f.with_argument(filter.to_strawberry(), "filter") - f.with_argument(sort.to_strawberry(), "sort") - f.with_argument(int, "skip", default_value=0) - f.with_argument(int, "take", default_value=10) + f.with_argument("filter", filter.to_strawberry()) + f.with_argument("sort", sort.to_strawberry()) + f.with_argument("skip", int, default_value=0) + f.with_argument("take", int, default_value=10) return f - - @staticmethod - def _build_resolver(f: "Field"): - params: list[inspect.Parameter] = [] - for arg in f.arguments.values(): - ann = Optional[arg.type] if arg.optional else arg.type - - if arg.default_value is None: - param = inspect.Parameter( - arg.name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=ann, - ) - else: - param = inspect.Parameter( - arg.name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=ann, - default=arg.default_value, - ) - - params.append(param) - - sig = inspect.Signature(parameters=params, return_annotation=f.type) - - async def _resolver(*args, **kwargs): - if f.resolver is None: - return None - - if iscoroutinefunction(f.resolver): - return await f.resolver(*args, **kwargs) - return f.resolver(*args, **kwargs) - - _resolver.__signature__ = sig - return _resolver - - def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: - sig = getattr(resolver, "__signature__", None) - - @functools.wraps(resolver) - async def _auth_resolver(*args, **kwargs): - if f.public: - return await self._run_resolver(resolver, *args, **kwargs) - - user = get_user() - - if user is None: - raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) - - if f.require_any_permission: - if not any([await user.has_permission(p) for p in f.require_any_permission]): - raise graphql_error(Forbidden(f"{f.name}: Permission denied")) - - if f.require_any: - perms, resolvers = f.require_any - if not any([await user.has_permission(p) for p in perms]): - ctx = QueryContext([x.name for x in await user.permissions]) - resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] - - if not any(resolved): - raise graphql_error(Forbidden(f"{f.name}: Permission denied")) - - return await self._run_resolver(resolver, *args, **kwargs) - - if sig: - _auth_resolver.__signature__ = sig - - return _auth_resolver - - @staticmethod - async def _run_resolver(r: Callable, *args, **kwargs): - if iscoroutinefunction(r): - return await r(*args, **kwargs) - return r(*args, **kwargs) - - def _field_to_strawberry(self, f: Field) -> Any: - resolver = None - try: - if f.arguments: - resolver = self._build_resolver(f) - elif not f.resolver: - resolver = lambda *_, **__: None - else: - ann = getattr(f.resolver, "__annotations__", {}) - if "return" not in ann or ann["return"] is None: - ann = dict(ann) - ann["return"] = f.type - f.resolver.__annotations__ = ann - resolver = f.resolver - - return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) - except StrawberryException as e: - raise Exception( - f"Error converting field '{f.name}' to strawberry field: {e}" - ) from e - - def to_strawberry(self) -> Type: - cls = self.__class__ - if TypeCollector.has(cls): - return TypeCollector.get(cls) - - annotations: dict[str, Any] = {} - namespace: dict[str, Any] = {} - - for name, f in self._fields.items(): - annotations[name] = f.type - namespace[name] = self._field_to_strawberry(f) - - namespace["__annotations__"] = annotations - gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) - TypeCollector.set(cls, gql_type) - return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/root_mutation.py b/src/cpl-graphql/cpl/graphql/schema/root_mutation.py new file mode 100644 index 00000000..8855d8e7 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_mutation.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.mutation import Mutation + + +class RootMutation(Mutation): + def __init__(self): + Mutation.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index c1c43cdc..9141f455 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -6,6 +6,7 @@ import strawberry from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery @@ -25,7 +26,17 @@ class Schema: @property def query(self) -> RootQuery: - return self._provider.get_service(RootQuery) + query = self._provider.get_service(RootQuery) + if not query: + raise ValueError("RootQuery not registered in service provider") + return query + + @property + def mutation(self) -> RootMutation: + mutation = self._provider.get_service(RootMutation) + if not mutation: + raise ValueError("RootMutation not registered in service provider") + return mutation def with_type(self, t: Type[StrawberryProtocol]) -> Self: self._types[t.__name__] = t @@ -43,13 +54,13 @@ class Schema: def build(self) -> strawberry.Schema: logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) - query = self._provider.get_service(RootQuery) - if not query: - raise ValueError("RootQuery not registered in service provider") + + query = self.query + mutation = self.mutation self._schema = strawberry.Schema( - query=query.to_strawberry(), - mutation=None, + query=query.to_strawberry() if query.fields_count > 0 else None, + mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None, subscription=None, types=self._get_types(), ) From 39351a5eb9b049d6dbf67de27dff0ae237e3860b Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 21:53:15 +0200 Subject: [PATCH 12/20] Recursive types #181 --- example/api/src/main.py | 8 +-- example/api/src/model/author_query.py | 12 ++-- example/api/src/model/post_query.py | 14 ++-- example/api/src/queries/hello.py | 66 +++++++++---------- src/cpl-api/cpl/api/application/web_app.py | 4 +- src/cpl-api/cpl/api/middleware/request.py | 5 +- .../cpl/auth/permission/role_seeder.py | 4 +- .../auth/schema/_administration/api_key.py | 4 +- .../auth/schema/_administration/auth_user.py | 4 +- .../schema/_permission/api_key_permission.py | 4 +- .../cpl/auth/schema/_permission/permission.py | 4 +- .../cpl/auth/schema/_permission/role.py | 4 +- .../schema/_permission/role_permission.py | 4 +- .../cpl/auth/schema/_permission/role_user.py | 4 +- .../cpl/core/utils/credential_manager.py | 1 + .../cpl/database/abc/db_join_model_abc.py | 4 +- .../cpl/database/abc/db_model_abc.py | 25 ++++--- .../cpl/database/schema/executed_migration.py | 4 +- .../cpl/graphql/_endpoints/graphiql.py | 7 +- .../cpl/graphql/_endpoints/playground.py | 6 +- src/cpl-graphql/cpl/graphql/abc/query_abc.py | 54 +++++++++++---- src/cpl-graphql/cpl/graphql/auth/__init__.py | 0 .../graphql/auth/administration/__init__.py | 0 .../administration/auth_user_graph_type.py | 12 ++++ .../cpl/graphql/auth/graphql_auth_module.py | 6 ++ src/cpl-graphql/cpl/graphql/error.py | 2 +- src/cpl-graphql/cpl/graphql/graphql_module.py | 6 ++ src/cpl-graphql/cpl/graphql/query_context.py | 8 +-- .../cpl/graphql/schema/collection.py | 1 - .../cpl/graphql/schema/db_model_graph_type.py | 60 +++++++++++++++++ src/cpl-graphql/cpl/graphql/schema/field.py | 8 ++- .../cpl/graphql/schema/filter/date_filter.py | 2 +- .../graphql/schema/filter/db_model_filter.py | 20 ++++++ .../cpl/graphql/schema/filter/int_filter.py | 2 +- .../cpl/graphql/schema/graph_type.py | 2 +- src/cpl-graphql/cpl/graphql/schema/input.py | 19 ++++-- .../cpl/graphql/schema/sort/sort_order.py | 2 +- .../cpl/graphql/service/graphql.py | 8 +-- src/cpl-graphql/cpl/graphql/service/schema.py | 5 +- src/cpl-graphql/cpl/graphql/typing.py | 6 +- .../cpl/graphql/utils/type_collector.py | 4 +- 41 files changed, 281 insertions(+), 134 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/auth/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py diff --git a/example/api/src/main.py b/example/api/src/main.py index fdd7dff0..c57d0e39 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,7 +1,7 @@ from starlette.responses import JSONResponse from api.src.queries.cities import CityGraphType, CityFilter, CitySort -from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType +from api.src.queries.hello import UserGraphType#, AuthUserFilter, AuthUserSort, AuthUserGraphType from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder @@ -47,9 +47,9 @@ def main(): .add_transient(UserGraphType) .add_transient(UserFilter) .add_transient(UserSort) - .add_transient(AuthUserGraphType) - .add_transient(AuthUserFilter) - .add_transient(AuthUserSort) + # .add_transient(AuthUserGraphType) + # .add_transient(AuthUserFilter) + # .add_transient(AuthUserSort) .add_transient(HelloQuery) # test data .add_singleton(TestDataSeeder) diff --git a/example/api/src/model/author_query.py b/example/api/src/model/author_query.py index ae365a7c..3fa4ab65 100644 --- a/example/api/src/model/author_query.py +++ b/example/api/src/model/author_query.py @@ -1,12 +1,12 @@ -from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder from model.author import Author -class AuthorFilter(Filter[Author]): +class AuthorFilter(DbModelFilter[Author]): def __init__(self): - Filter.__init__(self) + DbModelFilter.__init__(self, public=True) self.int_field("id") self.string_field("firstName") self.string_field("lastName") @@ -18,10 +18,10 @@ class AuthorSort(Sort[Author]): self.field("firstName", SortOrder) self.field("lastName", SortOrder) -class AuthorGraphType(GraphType[Author]): +class AuthorGraphType(DbModelGraphType[Author]): def __init__(self): - GraphType.__init__(self) + DbModelGraphType.__init__(self, public=True) self.int_field( "id", diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 6334c51e..d12f308c 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,6 +1,6 @@ from cpl.graphql.query_context import QueryContext -from cpl.graphql.schema.filter.filter import Filter -from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter from cpl.graphql.schema.input import Input from cpl.graphql.schema.mutation import Mutation from cpl.graphql.schema.sort.sort import Sort @@ -11,9 +11,9 @@ from model.post import Post from model.post_dao import PostDao -class PostFilter(Filter[Post]): +class PostFilter(DbModelFilter[Post]): def __init__(self): - Filter.__init__(self) + DbModelFilter.__init__(self, public=True) self.int_field("id") self.filter_field("author", AuthorFilter) self.string_field("title") @@ -26,15 +26,15 @@ class PostSort(Sort[Post]): self.field("title", SortOrder) self.field("content", SortOrder) -class PostGraphType(GraphType[Post]): +class PostGraphType(DbModelGraphType[Post]): def __init__(self, authors: AuthorDao): - GraphType.__init__(self) + DbModelGraphType.__init__(self, public=True) self.int_field( "id", resolver=lambda root: root.id, - ).with_public(True) + ).with_optional().with_public(True) async def _a(root: Post): return await authors.get_by_id(root.author_id) diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index 864e39ab..88d9af27 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -11,32 +11,32 @@ from cpl.graphql.schema.sort.sort_order import SortOrder users = [User(i, f"User {i}") for i in range(1, 101)] cities = [City(i, f"City {i}") for i in range(1, 101)] -class AuthUserFilter(Filter[AuthUser]): - def __init__(self): - Filter.__init__(self) - self.field("id", int) - self.field("username", str) - - -class AuthUserSort(Sort[AuthUser]): - def __init__(self): - Sort.__init__(self) - self.field("id", SortOrder) - self.field("username", SortOrder) - -class AuthUserGraphType(GraphType[AuthUser]): - - def __init__(self): - GraphType.__init__(self) - - self.int_field( - "id", - resolver=lambda root: root.id, - ) - self.string_field( - "username", - resolver=lambda root: root.username, - ) +# class AuthUserFilter(Filter[AuthUser]): +# def __init__(self): +# Filter.__init__(self) +# self.field("id", int) +# self.field("username", str) +# +# +# class AuthUserSort(Sort[AuthUser]): +# def __init__(self): +# Sort.__init__(self) +# self.field("id", SortOrder) +# self.field("username", SortOrder) +# +# class AuthUserGraphType(GraphType[AuthUser]): +# +# def __init__(self): +# GraphType.__init__(self) +# +# self.int_field( +# "id", +# resolver=lambda root: root.id, +# ) +# self.string_field( +# "username", +# resolver=lambda root: root.username, +# ) class HelloQuery(Query): def __init__(self): @@ -60,10 +60,10 @@ class HelloQuery(Query): CitySort, resolver=lambda: cities, ) - self.dao_collection_field( - AuthUserGraphType, - AuthUserDao, - "authUsers", - AuthUserFilter, - AuthUserSort, - ) + # self.dao_collection_field( + # AuthUserGraphType, + # AuthUserDao, + # "authUsers", + # AuthUserFilter, + # AuthUserSort, + # ) diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index deeb2710..b63b4700 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -36,7 +36,9 @@ from cpl.dependency.typing import Modules class WebApp(WebAppABC): def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None): - WebAppABC.__init__(self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])) + WebAppABC.__init__( + self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or []) + ) self._app: Starlette | None = None self._logger = services.get_service(APILogger) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 2dc24bc5..6ddea35c 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -21,7 +21,9 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): - def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): + def __init__( + self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao + ): ASGIMiddleware.__init__(self, app) self._provider = provider @@ -92,5 +94,6 @@ class RequestMiddleware(ASGIMiddleware): except Exception as e: self._logger.debug(f"Silent user binding failed: {e}") + def get_request() -> Optional[TRequest]: return _request_context.get() diff --git a/src/cpl-auth/cpl/auth/permission/role_seeder.py b/src/cpl-auth/cpl/auth/permission/role_seeder.py index 2c7687bd..15925299 100644 --- a/src/cpl-auth/cpl/auth/permission/role_seeder.py +++ b/src/cpl-auth/cpl/auth/permission/role_seeder.py @@ -23,8 +23,8 @@ class RoleSeeder(DataSeederABC): role_permission_dao: RolePermissionDao, api_key_dao: ApiKeyDao, api_key_permission_dao: ApiKeyPermissionDao, - user_dao: AuthUserDao, - role_user_dao: RoleUserDao, + user_dao: AuthUserDao, + role_user_dao: RoleUserDao, ): DataSeederABC.__init__(self) self._logger = logger diff --git a/src/cpl-auth/cpl/auth/schema/_administration/api_key.py b/src/cpl-auth/cpl/auth/schema/_administration/api_key.py index 995628e2..9a6d5f6c 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/api_key.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/api_key.py @@ -25,8 +25,8 @@ class ApiKey(DbModelABC[Self]): key: Union[str, bytes], deleted: bool = False, editor_id: Optional[Id] = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._identifier = identifier diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py b/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py index e9eff14d..950a321c 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py @@ -20,8 +20,8 @@ class AuthUser(DbModelABC[Self]): keycloak_id: str, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._keycloak_id = keycloak_id diff --git a/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py index 59132955..5a807e76 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py @@ -16,8 +16,8 @@ class ApiKeyPermission(DbJoinModelABC): permission_id: SerialId, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated) self._api_key_id = api_key_id diff --git a/src/cpl-auth/cpl/auth/schema/_permission/permission.py b/src/cpl-auth/cpl/auth/schema/_permission/permission.py index 8db9c477..6ca5849a 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/permission.py @@ -13,8 +13,8 @@ class Permission(DbModelABC[Self]): description: str, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role.py b/src/cpl-auth/cpl/auth/schema/_permission/role.py index 3c1b0a1f..d5da2c12 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role.py @@ -17,8 +17,8 @@ class Role(DbModelABC[Self]): description: str, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py index c58d8682..8038227b 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py @@ -16,8 +16,8 @@ class RolePermission(DbModelABC[Self]): permission_id: SerialId, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._role_id = role_id diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py index 72504768..90c4e05c 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py @@ -16,8 +16,8 @@ class RoleUser(DbJoinModelABC): role_id: SerialId, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated) self._user_id = user_id diff --git a/src/cpl-core/cpl/core/utils/credential_manager.py b/src/cpl-core/cpl/core/utils/credential_manager.py index 126afd6a..46df3b43 100644 --- a/src/cpl-core/cpl/core/utils/credential_manager.py +++ b/src/cpl-core/cpl/core/utils/credential_manager.py @@ -11,6 +11,7 @@ class CredentialManager: @classmethod def with_secret(cls, file: str = None): from cpl.core.log import Logger + if file is None: file = ".secret" diff --git a/src/cpl-database/cpl/database/abc/db_join_model_abc.py b/src/cpl-database/cpl/database/abc/db_join_model_abc.py index 55327419..42388418 100644 --- a/src/cpl-database/cpl/database/abc/db_join_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_join_model_abc.py @@ -13,8 +13,8 @@ class DbJoinModelABC[T](DbModelABC[T]): foreign_id: Id, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) diff --git a/src/cpl-database/cpl/database/abc/db_model_abc.py b/src/cpl-database/cpl/database/abc/db_model_abc.py index 5791afe3..4f38a8de 100644 --- a/src/cpl-database/cpl/database/abc/db_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_model_abc.py @@ -2,7 +2,10 @@ from abc import ABC from datetime import datetime, timezone from typing import Optional, Generic +from async_property import async_property + from cpl.core.typing import Id, SerialId, T +from cpl.dependency import get_provider class DbModelABC(ABC, Generic[T]): @@ -11,8 +14,8 @@ class DbModelABC(ABC, Generic[T]): id: Id, deleted: bool = False, editor_id: SerialId | None = None, - created: datetime | None= None, - updated: datetime | None= None, + created: datetime | None = None, + updated: datetime | None = None, ): self._id = id self._deleted = deleted @@ -41,14 +44,16 @@ class DbModelABC(ABC, Generic[T]): def editor_id(self, value: SerialId): self._editor_id = value - # @async_property - # async def editor(self): - # if self._editor_id is None: - # return None - # - # from data.schemas.administration.user_dao import userDao - # - # return await userDao.get_by_id(self._editor_id) + @async_property + async def editor(self): + if self._editor_id is None: + return None + + from cpl.auth.schema import AuthUserDao + + auth_user_dao = get_provider().get_service(AuthUserDao) + + return await auth_user_dao.get_by_id(self._editor_id) @property def created(self) -> datetime: diff --git a/src/cpl-database/cpl/database/schema/executed_migration.py b/src/cpl-database/cpl/database/schema/executed_migration.py index 02b99dc3..b6ec58ac 100644 --- a/src/cpl-database/cpl/database/schema/executed_migration.py +++ b/src/cpl-database/cpl/database/schema/executed_migration.py @@ -8,8 +8,8 @@ class ExecutedMigration(DbModelABC[Self]): def __init__( self, migration_id: str, - created: datetime | None= None, - modified: datetime | None= None, + created: datetime | None = None, + modified: datetime | None = None, ): DbModelABC.__init__(self, migration_id, False, created, modified) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py index 2aedb538..70a81ad3 100644 --- a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py @@ -1,7 +1,9 @@ from starlette.responses import HTMLResponse + async def graphiql_endpoint(request): - return HTMLResponse(""" + return HTMLResponse( + """ @@ -34,4 +36,5 @@ async def graphiql_endpoint(request): - """) + """ + ) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/playground.py b/src/cpl-graphql/cpl/graphql/_endpoints/playground.py index 68e59fdf..969cd506 100644 --- a/src/cpl-graphql/cpl/graphql/_endpoints/playground.py +++ b/src/cpl-graphql/cpl/graphql/_endpoints/playground.py @@ -3,7 +3,8 @@ from starlette.responses import Response, HTMLResponse async def playground_endpoint(request: Request) -> Response: - return HTMLResponse(""" + return HTMLResponse( + """ @@ -24,4 +25,5 @@ async def playground_endpoint(request: Request) -> Response: - """) + """ + ) diff --git a/src/cpl-graphql/cpl/graphql/abc/query_abc.py b/src/cpl-graphql/cpl/graphql/abc/query_abc.py index 5023ebea..8cad66d2 100644 --- a/src/cpl-graphql/cpl/graphql/abc/query_abc.py +++ b/src/cpl-graphql/cpl/graphql/abc/query_abc.py @@ -1,7 +1,8 @@ import functools import inspect +import types from abc import ABC -from asyncio import iscoroutinefunction +from asyncio import iscoroutinefunction, iscoroutine from typing import Callable, Type, Any, Optional import strawberry @@ -9,11 +10,12 @@ from strawberry.exceptions import StrawberryException from cpl.api import Unauthorized, Forbidden from cpl.core.ctx.user_context import get_user +from cpl.dependency import get_provider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.error import graphql_error from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.field import Field -from cpl.graphql.typing import Resolver +from cpl.graphql.typing import Resolver, AttributeName from cpl.graphql.utils.type_collector import TypeCollector @@ -36,31 +38,37 @@ class QueryABC(StrawberryProtocol, ABC): def field( self, - name: str, + name: AttributeName, t: type, resolver: Resolver = None, ) -> Field: from cpl.graphql.schema.field import Field + if isinstance(name, property): + name = name.fget.__name__ + self._fields[name] = Field(name, t, resolver) return self._fields[name] - def string_field(self, name: str, resolver: Resolver = None) -> Field: + def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field: return self.field(name, str, resolver) - def int_field(self, name: str, resolver: Resolver = None) -> Field: + def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field: return self.field(name, int, resolver) - def float_field(self, name: str, resolver: Resolver = None) -> Field: + def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field: return self.field(name, float, resolver) - def bool_field(self, name: str, resolver: Resolver = None) -> Field: + def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field: return self.field(name, bool, resolver) - def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: + def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field: return self.field(name, list[t], resolver) def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: + if not isinstance(t, type) and callable(t): + return self.field(name, t, resolver) + return self.field(name, t().to_strawberry(), resolver) @staticmethod @@ -137,9 +145,10 @@ class QueryABC(StrawberryProtocol, ABC): @staticmethod async def _run_resolver(r: Callable, *args, **kwargs): - if iscoroutinefunction(r): - return await r(*args, **kwargs) - return r(*args, **kwargs) + result = r(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result def _field_to_strawberry(self, f: Field) -> Any: resolver = None @@ -147,7 +156,7 @@ class QueryABC(StrawberryProtocol, ABC): if f.arguments: resolver = self._build_resolver(f) elif not f.resolver: - resolver = lambda *_, **__: None + resolver = lambda root: None else: ann = getattr(f.resolver, "__annotations__", {}) if "return" not in ann or ann["return"] is None: @@ -165,14 +174,31 @@ class QueryABC(StrawberryProtocol, ABC): if TypeCollector.has(cls): return TypeCollector.get(cls) + gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {}) + # register early to handle recursive types + TypeCollector.set(cls, gql_cls) + annotations: dict[str, Any] = {} namespace: dict[str, Any] = {} for name, f in self._fields.items(): - annotations[name] = f.type + t = f.type + + if callable(t) and not isinstance(t, type): + _t = get_provider().get_service(t()) + if isinstance(_t, StrawberryProtocol): + t = _t.to_strawberry() + else: + t = _t + + annotations[name] = t if not f.optional else Optional[t] namespace[name] = self._field_to_strawberry(f) namespace["__annotations__"] = annotations - gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + for k, v in namespace.items(): + setattr(gql_cls, k, v) + + gql_cls.__annotations__ = annotations + gql_type = strawberry.type(gql_cls) TypeCollector.set(cls, gql_type) return gql_type diff --git a/src/cpl-graphql/cpl/graphql/auth/__init__.py b/src/cpl-graphql/cpl/graphql/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/__init__.py b/src/cpl-graphql/cpl/graphql/auth/administration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py new file mode 100644 index 00000000..d96af34e --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py @@ -0,0 +1,12 @@ +from cpl.auth.schema import AuthUser +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class AuthUserGraphType(DbModelGraphType): + + def __init__(self): + DbModelGraphType.__init__(self) + + self.string_field(AuthUser.keycloak_id, lambda root: root.keycloak_id) + self.string_field(AuthUser.username, lambda root: root.username) + self.string_field(AuthUser.email, lambda root: root.email) diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py new file mode 100644 index 00000000..dc53f754 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -0,0 +1,6 @@ +from cpl.dependency.module.module import Module +from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType + + +class GraphQLAuthModule(Module): + transient = [AuthUserGraphType] diff --git a/src/cpl-graphql/cpl/graphql/error.py b/src/cpl-graphql/cpl/graphql/error.py index e96e41c1..ecab2c06 100644 --- a/src/cpl-graphql/cpl/graphql/error.py +++ b/src/cpl-graphql/cpl/graphql/error.py @@ -11,4 +11,4 @@ def graphql_error(api_error: APIError) -> GraphQLError: "code": api_error.status_code, }, original_error=api_error, - ) \ No newline at end of file + ) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index b749d16e..05a36787 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,6 +1,8 @@ from cpl.api.api_module import ApiModule +from cpl.dependency import ServiceCollection from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.filter import Filter @@ -18,6 +20,10 @@ class GraphQLModule(Module): scoped = [GraphQLService] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] + @staticmethod + def register(collection: ServiceCollection): + collection.add_module(GraphQLAuthModule) + @staticmethod def configure(services: ServiceProvider) -> None: schema = services.get_service(Schema) diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py index 0c8f5781..44a916ee 100644 --- a/src/cpl-graphql/cpl/graphql/query_context.py +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -9,13 +9,7 @@ from cpl.core.ctx import get_user class QueryContext: - def __init__( - self, - user_permissions: Optional[list[Enum | Permission]], - is_mutation: bool = False, - *args, - **kwargs - ): + def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs): self._user = get_user() self._user_permissions = user_permissions or [] diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 1d37a626..0dbc66c0 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -19,7 +19,6 @@ class CollectionGraphTypeFactory: if not node_t: raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider") - gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type gql_type = strawberry.type( diff --git a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py new file mode 100644 index 00000000..32f6cfbc --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py @@ -0,0 +1,60 @@ +from typing import Type, Optional, Generic, Annotated + +import strawberry + +from cpl.core.typing import T +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.query import Query + + +class DbModelGraphType(GraphType[T], Generic[T]): + + def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False): + Query.__init__(self) + + self._dao: Optional[DataAccessObjectABC] = None + + if t_dao is not None: + dao = self._provider.get_service(t_dao) + if dao is not None: + self._dao = dao + + self.int_field("id", lambda root: root.id).with_public(public) + self.bool_field("deleted", lambda root: root.deleted).with_public(public) + + from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType + + self.object_field("editor", lambda: AuthUserGraphType, lambda root: root.editor).with_public(public) + + self.string_field("created", lambda root: root.created).with_public(public) + self.string_field("updated", lambda root: root.updated).with_public(public) + + # if with_history: + # if self._dao is None: + # raise ValueError("DAO must be provided to enable history") + # self.set_field("history", self._resolve_history).with_public(public) + + self._history_reference_daos: dict[DataAccessObjectABC, str] = {} + + async def _resolve_history(self, root): + if self._dao is None: + raise Exception("DAO not set for history query") + + history = sorted( + [await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)], + key=lambda h: h.updated, + reverse=True, + ) + return history + + def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None): + """ + Set the reference DAO for history resolution. + :param dao: + :param key: The key to use for resolving history. + :return: + """ + if key is None: + key = "id" + self._history_reference_daos[dao] = key diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index 8eceba25..cea91c93 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -5,7 +5,7 @@ from cpl.graphql.schema.argument import Argument from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers -class Field: +class Field: def __init__( self, @@ -87,7 +87,7 @@ class Field: self._resolver = resolver return self - def with_optional(self, optional: bool) -> Self: + def with_optional(self, optional: bool = True) -> Self: self._optional = optional return self @@ -99,7 +99,9 @@ class Field: self._default = default return self - def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument: + def with_argument( + self, name: str, arg_type: type, description: str = None, default_value=None, optional=True + ) -> Argument: if name in self._args: raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") self._args[name] = Argument(name, arg_type, description, default_value, optional) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py index 2dd1bcf8..0149a3b9 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py @@ -15,4 +15,4 @@ class DateFilter(Input[datetime]): self.field("isNull", datetime, optional=True) self.field("isNotNull", datetime, optional=True) self.field("in", list[datetime], optional=True) - self.field("notIn", list[datetime], optional=True) \ No newline at end of file + self.field("notIn", list[datetime], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py new file mode 100644 index 00000000..860712fe --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -0,0 +1,20 @@ +from typing import Type, Generic + +from cpl.core.typing import T +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter +from cpl.graphql.schema.input import Input + + +class DbModelFilter(Filter[T], Generic[T]): + def __init__(self, public: bool = False): + Filter.__init__(self) + + self.field("id", IntFilter).with_public(public) + self.field("deleted", BoolFilter).with_public(public) + # self.field("editor", AuthUserFilter) + self.field("created", DateFilter).with_public(public) + self.field("updated", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py index be9eba74..801ad562 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py @@ -13,4 +13,4 @@ class IntFilter(Input[int]): self.field("isNull", int, optional=True) self.field("isNotNull", int, optional=True) self.field("in", list[int], optional=True) - self.field("notIn", list[int], optional=True) \ No newline at end of file + self.field("notIn", list[int], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/graph_type.py b/src/cpl-graphql/cpl/graphql/schema/graph_type.py index e829b82d..b4d5b422 100644 --- a/src/cpl-graphql/cpl/graphql/schema/graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/graph_type.py @@ -7,4 +7,4 @@ from cpl.graphql.schema.query import Query class GraphType(Query, Generic[T]): def __init__(self): - Query.__init__(self) \ No newline at end of file + Query.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 6e639db3..bcba7ae0 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -5,10 +5,12 @@ import strawberry from cpl.core.typing import T from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field +from cpl.graphql.typing import AttributeName from cpl.graphql.utils.type_collector import TypeCollector _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} + class Input(StrawberryProtocol, Generic[T]): def __init__(self): self._fields: Dict[str, Field] = {} @@ -37,26 +39,29 @@ class Input(StrawberryProtocol, Generic[T]): def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field: + def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field: + if isinstance(name, property): + name = name.fget.__name__ + self._fields[name] = Field(name, typ, optional=optional) return self._fields[name] - def string_field(self, name: str, optional: bool = True) -> Field: + def string_field(self, name: AttributeName, optional: bool = True) -> Field: return self.field(name, str) - def int_field(self, name: str, optional: bool = True) -> Field: + def int_field(self, name: AttributeName, optional: bool = True) -> Field: return self.field(name, int, optional) - def float_field(self, name: str, optional: bool = True) -> Field: + def float_field(self, name: AttributeName, optional: bool = True) -> Field: return self.field(name, float, optional) - def bool_field(self, name: str, optional: bool = True) -> Field: + def bool_field(self, name: AttributeName, optional: bool = True) -> Field: return self.field(name, bool, optional) - def list_field(self, name: str, t: type, optional: bool = True) -> Field: + def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field: return self.field(name, list[t], optional) - def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field: + def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field: return self.field(name, t().to_strawberry(), optional) def to_strawberry(self) -> Type: diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py index cb8e8177..db75e06e 100644 --- a/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py +++ b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py @@ -3,4 +3,4 @@ from enum import Enum, auto class SortOrder(Enum): ASC = "ASC" - DESC = "DESC" \ No newline at end of file + DESC = "DESC" diff --git a/src/cpl-graphql/cpl/graphql/service/graphql.py b/src/cpl-graphql/cpl/graphql/service/graphql.py index cb4ee667..7262906d 100644 --- a/src/cpl-graphql/cpl/graphql/service/graphql.py +++ b/src/cpl-graphql/cpl/graphql/service/graphql.py @@ -16,10 +16,10 @@ class GraphQLService: self._schema = schema.schema async def execute( - self, - query: str, - variables: Optional[Dict[str, Any]], - request: TRequest, + self, + query: str, + variables: Optional[Dict[str, Any]], + request: TRequest, ) -> Dict[str, Any]: result = await self._schema.execute( query, diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 9141f455..3142adaa 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -6,6 +6,7 @@ import strawberry from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery @@ -16,7 +17,9 @@ class Schema: self._logger = logger self._provider = provider - self._types: dict[str, Type[StrawberryProtocol]] = {} + self._types: dict[str, Type[StrawberryProtocol]] = { + "AuthUserGraphType": AuthUserGraphType, + } self._schema = None diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py index d36e3119..3dd33106 100644 --- a/src/cpl-graphql/cpl/graphql/typing.py +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -7,9 +7,7 @@ from cpl.graphql.query_context import QueryContext TQuery = Type["Query"] Resolver = Callable ScalarType = str | int | float | bool | object - +AttributeName = str | property TRequireAnyPermissions = List[Enum | Permissions] | None -TRequireAnyResolvers = List[ - Callable[[QueryContext], bool | Awaitable[bool]], -] +TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],] TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers] diff --git a/src/cpl-graphql/cpl/graphql/utils/type_collector.py b/src/cpl-graphql/cpl/graphql/utils/type_collector.py index c51718bf..bf9e4332 100644 --- a/src/cpl-graphql/cpl/graphql/utils/type_collector.py +++ b/src/cpl-graphql/cpl/graphql/utils/type_collector.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Type, Any class TypeCollector: @@ -14,4 +14,4 @@ class TypeCollector: @classmethod def set(cls, base: type, gql_type: Type): - cls._registry[base] = gql_type \ No newline at end of file + cls._registry[base] = gql_type From df69f1c7256572e1cdb62227ccd0d0157b90c431 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 22:06:50 +0200 Subject: [PATCH 13/20] Recursive filter #181 --- .../auth/administration/auth_user_filter.py | 11 ++++++ .../cpl/graphql/auth/graphql_auth_module.py | 3 +- .../graphql/schema/filter/db_model_filter.py | 8 ++-- src/cpl-graphql/cpl/graphql/schema/input.py | 39 +++++++++++++------ 4 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py new file mode 100644 index 00000000..19264a46 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py @@ -0,0 +1,11 @@ +from cpl.auth.schema import AuthUser +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class AuthUserFilter(DbModelFilter[AuthUser]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("username", StringFilter).with_public(public) + self.field("email", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index dc53f754..a0724910 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -1,6 +1,7 @@ from cpl.dependency.module.module import Module +from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType class GraphQLAuthModule(Module): - transient = [AuthUserGraphType] + transient = [AuthUserGraphType, AuthUserFilter] diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py index 860712fe..aa4fb4d8 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -1,12 +1,10 @@ -from typing import Type, Generic +from typing import Generic from cpl.core.typing import T from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.int_filter import IntFilter -from cpl.graphql.schema.filter.string_filter import StringFilter -from cpl.graphql.schema.input import Input class DbModelFilter(Filter[T], Generic[T]): @@ -15,6 +13,8 @@ class DbModelFilter(Filter[T], Generic[T]): self.field("id", IntFilter).with_public(public) self.field("deleted", BoolFilter).with_public(public) - # self.field("editor", AuthUserFilter) + from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter + + self.field("editor", lambda: AuthUserFilter).with_public(public) self.field("created", DateFilter).with_public(public) self.field("updated", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index bcba7ae0..ce7817ab 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -1,8 +1,10 @@ +import types from typing import Generic, Dict, Type, Optional, Union, Any import strawberry from cpl.core.typing import T +from cpl.dependency import get_provider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field from cpl.graphql.typing import AttributeName @@ -39,7 +41,7 @@ class Input(StrawberryProtocol, Generic[T]): def get_fields(self) -> dict[str, Field]: return self._fields - def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field: + def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field: if isinstance(name, property): name = name.fget.__name__ @@ -62,6 +64,9 @@ class Input(StrawberryProtocol, Generic[T]): return self.field(name, list[t], optional) def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field: + if not isinstance(t, type) and callable(t): + return self.field(name, t, optional) + return self.field(name, t().to_strawberry(), optional) def to_strawberry(self) -> Type: @@ -69,20 +74,28 @@ class Input(StrawberryProtocol, Generic[T]): if TypeCollector.has(cls): return TypeCollector.get(cls) - annotations = {} - namespace = {} + gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {}) + # register early to handle recursive types + TypeCollector.set(cls, gql_cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} for name, f in self._fields.items(): - typ = f.type - if isinstance(typ, type) and issubclass(typ, Input): - typ = typ().to_strawberry() - elif isinstance(typ, Input): - typ = typ.to_strawberry() + t = f.type - ann = typ if not f.optional else Optional[typ] + if isinstance(t, types.FunctionType): + _t = get_provider().get_service(t()) + if _t is None: + raise ValueError(f"'{t()}' could not be resolved from the provider") + t = _t.to_strawberry() + elif isinstance(t, type) and issubclass(t, Input): + t = t().to_strawberry() + elif isinstance(t, Input): + t = t.to_strawberry() py_name = name + "_" if name in _PYTHON_KEYWORDS else name - annotations[py_name] = ann + annotations[py_name] = t if not f.optional else Optional[t] field_args = {} if py_name != name: @@ -93,6 +106,10 @@ class Input(StrawberryProtocol, Generic[T]): namespace["__annotations__"] = annotations - gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) + for k, v in namespace.items(): + setattr(gql_cls, k, v) + + gql_cls.__annotations__ = annotations + gql_type = strawberry.input(gql_cls) TypeCollector.set(cls, gql_type) return gql_type From e7e3712e084b1c34f8b33439364d0c4064d5983f Mon Sep 17 00:00:00 2001 From: edraft Date: Mon, 29 Sep 2025 08:31:59 +0200 Subject: [PATCH 14/20] Renamed AuthUsers -> Users & completed user gql #181 --- example/api/src/main.py | 16 +-- example/api/src/queries/hello.py | 18 +-- src/cpl-api/cpl/api/application/web_app.py | 5 +- .../cpl/api/middleware/authentication.py | 8 +- .../cpl/api/middleware/authorization.py | 4 +- src/cpl-api/cpl/api/middleware/request.py | 8 +- src/cpl-api/cpl/api/typing.py | 4 +- src/cpl-auth/cpl/auth/auth_module.py | 4 +- .../cpl/auth/permission/role_seeder.py | 4 +- src/cpl-auth/cpl/auth/schema/__init__.py | 4 +- .../_administration/{auth_user.py => user.py} | 20 ++-- .../{auth_user_dao.py => user_dao.py} | 16 +-- .../cpl/auth/schema/_permission/role_user.py | 6 +- .../cpl/auth/scripts/mysql/1-users.sql | 18 +-- .../cpl/auth/scripts/mysql/2-api-key.sql | 2 +- .../scripts/mysql/3-roles-permissions.sql | 28 ++--- .../scripts/mysql/4-api-key-permissions.sql | 2 +- .../cpl/auth/scripts/postgres/1-users.sql | 10 +- .../cpl/auth/scripts/postgres/2-api-key.sql | 2 +- .../scripts/postgres/3-roles-permissions.sql | 18 +-- .../postgres/4-api-key-permissions.sql | 2 +- src/cpl-core/cpl/core/ctx/user_context.py | 8 +- .../cpl/database/abc/db_model_abc.py | 6 +- .../cpl/database/abc/db_model_dao_abc.py | 2 +- .../cpl/database/table_manager.py | 10 +- .../cpl/graphql/application/graphql_app.py | 13 ++ .../administration/auth_user_graph_type.py | 12 -- .../auth/administration/user/__init__.py | 0 .../user_filter.py} | 4 +- .../administration/user/user_graph_type.py | 12 ++ .../auth/administration/user/user_input.py | 23 ++++ .../auth/administration/user/user_mutation.py | 112 ++++++++++++++++++ .../cpl/graphql/auth/graphql_auth_module.py | 22 +++- src/cpl-graphql/cpl/graphql/graphql_module.py | 6 - src/cpl-graphql/cpl/graphql/query_context.py | 4 +- .../cpl/graphql/schema/db_model_graph_type.py | 7 +- .../graphql/schema/filter/db_model_filter.py | 6 +- .../cpl/graphql/schema/filter/filter.py | 10 +- .../cpl/graphql/schema/mutation.py | 76 +++++++++++- src/cpl-graphql/cpl/graphql/service/schema.py | 5 +- 40 files changed, 387 insertions(+), 150 deletions(-) rename src/cpl-auth/cpl/auth/schema/_administration/{auth_user.py => user.py} (78%) rename src/cpl-auth/cpl/auth/schema/_administration/{auth_user_dao.py => user_dao.py} (83%) delete mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/user/__init__.py rename src/cpl-graphql/cpl/graphql/auth/administration/{auth_user_filter.py => user/user_filter.py} (80%) create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py diff --git a/example/api/src/main.py b/example/api/src/main.py index c57d0e39..d4fae505 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,17 +1,18 @@ from starlette.responses import JSONResponse from api.src.queries.cities import CityGraphType, CityFilter, CitySort -from api.src.queries.hello import UserGraphType#, AuthUserFilter, AuthUserSort, AuthUserGraphType +from api.src.queries.hello import UserGraphType#, UserFilter, UserSort, UserGraphType from api.src.queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder -from cpl.auth.schema import AuthUser, Role +from cpl.auth.schema import User, Role from cpl.core.configuration import Configuration from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule from cpl.graphql.application.graphql_app import GraphQLApp +from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule from cpl.graphql.graphql_module import GraphQLModule from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort @@ -38,8 +39,9 @@ def main(): .add_module(MySQLModule) .add_module(ApiModule) .add_module(GraphQLModule) + .add_module(GraphQLAuthModule) .add_scoped(ScopedService) - .add_cache(AuthUser) + .add_cache(User) .add_cache(Role) .add_transient(CityGraphType) .add_transient(CityFilter) @@ -47,9 +49,9 @@ def main(): .add_transient(UserGraphType) .add_transient(UserFilter) .add_transient(UserSort) - # .add_transient(AuthUserGraphType) - # .add_transient(AuthUserFilter) - # .add_transient(AuthUserSort) + # .add_transient(UserGraphType) + # .add_transient(UserFilter) + # .add_transient(UserSort) .add_transient(HelloQuery) # test data .add_singleton(TestDataSeeder) @@ -100,7 +102,7 @@ def main(): app.with_permissions(PostPermissions) provider = builder.service_provider - user_cache = provider.get_service(Cache[AuthUser]) + user_cache = provider.get_service(Cache[User]) role_cache = provider.get_service(Cache[Role]) if role_cache == user_cache: diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index 88d9af27..c53ce008 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -1,7 +1,7 @@ from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City from api.src.queries.user import User, UserFilter, UserSort, UserGraphType from cpl.api.middleware.request import get_request -from cpl.auth.schema import AuthUserDao, AuthUser +from cpl.auth.schema import UserDao, User from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.query import Query @@ -11,20 +11,20 @@ from cpl.graphql.schema.sort.sort_order import SortOrder users = [User(i, f"User {i}") for i in range(1, 101)] cities = [City(i, f"City {i}") for i in range(1, 101)] -# class AuthUserFilter(Filter[AuthUser]): +# class UserFilter(Filter[User]): # def __init__(self): # Filter.__init__(self) # self.field("id", int) # self.field("username", str) # # -# class AuthUserSort(Sort[AuthUser]): +# class UserSort(Sort[User]): # def __init__(self): # Sort.__init__(self) # self.field("id", SortOrder) # self.field("username", SortOrder) # -# class AuthUserGraphType(GraphType[AuthUser]): +# class UserGraphType(GraphType[User]): # # def __init__(self): # GraphType.__init__(self) @@ -61,9 +61,9 @@ class HelloQuery(Query): resolver=lambda: cities, ) # self.dao_collection_field( - # AuthUserGraphType, - # AuthUserDao, - # "authUsers", - # AuthUserFilter, - # AuthUserSort, + # UserGraphType, + # UserDao, + # "Users", + # UserFilter, + # UserSort, # ) diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index b63b4700..f994444e 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -214,6 +214,9 @@ class WebApp(WebAppABC): self.with_middleware(AuthorizationMiddleware) return self + async def _log_before_startup(self): + self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") + async def main(self): self._logger.debug(f"Preparing API") self._validate_policies() @@ -237,7 +240,7 @@ class WebApp(WebAppABC): else: app = self._app - self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") + await self._log_before_startup() config = uvicorn.Config( app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index 9b45c076..8b40cdd1 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -7,13 +7,13 @@ from cpl.api.logger import APILogger from cpl.api.middleware.request import get_request from cpl.api.router import Router from cpl.auth.keycloak import KeycloakClient -from cpl.auth.schema import AuthUserDao, AuthUser +from cpl.auth.schema import UserDao, User from cpl.core.ctx import set_user class AuthenticationMiddleware(ASGIMiddleware): - def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): + def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._logger = logger @@ -72,12 +72,12 @@ class AuthenticationMiddleware(ASGIMiddleware): return await self._call_next(scope, receive, send) - async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser: + async def _get_or_crate_user(self, keycloak_id: str) -> User: existing = await self._user_dao.find_by_keycloak_id(keycloak_id) if existing is not None: return existing - user = AuthUser(0, keycloak_id) + user = User(0, keycloak_id) uid = await self._user_dao.create(user) return await self._user_dao.get_by_id(uid) diff --git a/src/cpl-api/cpl/api/middleware/authorization.py b/src/cpl-api/cpl/api/middleware/authorization.py index b0b0d18c..64347cdc 100644 --- a/src/cpl-api/cpl/api/middleware/authorization.py +++ b/src/cpl-api/cpl/api/middleware/authorization.py @@ -7,13 +7,13 @@ from cpl.api.middleware.request import get_request from cpl.api.model.validation_match import ValidationMatch from cpl.api.registry.policy import PolicyRegistry from cpl.api.router import Router -from cpl.auth.schema._administration.auth_user_dao import AuthUserDao +from cpl.auth.schema._administration.user_dao import UserDao from cpl.core.ctx.user_context import get_user class AuthorizationMiddleware(ASGIMiddleware): - def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): + def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._logger = logger diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 6ddea35c..05a291e3 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -10,8 +10,8 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.typing import TRequest from cpl.auth.keycloak.keycloak_client import KeycloakClient -from cpl.auth.schema import AuthUser -from cpl.auth.schema._administration.auth_user_dao import AuthUserDao +from cpl.auth.schema import User +from cpl.auth.schema._administration.user_dao import UserDao from cpl.core.ctx import set_user from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider @@ -22,7 +22,7 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): def __init__( - self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao + self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao ): ASGIMiddleware.__init__(self, app) @@ -80,7 +80,7 @@ class RequestMiddleware(ASGIMiddleware): user = await self._user_dao.find_by_keycloak_id(keycloak_id) if not user: - user = AuthUser(0, keycloak_id) + user = User(0, keycloak_id) uid = await self._user_dao.create(user) user = await self._user_dao.get_by_id(uid) diff --git a/src/cpl-api/cpl/api/typing.py b/src/cpl-api/cpl/api/typing.py index a62d4927..8d5f0c73 100644 --- a/src/cpl-api/cpl/api/typing.py +++ b/src/cpl-api/cpl/api/typing.py @@ -7,7 +7,7 @@ from starlette.types import ASGIApp from starlette.websockets import WebSocket from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware -from cpl.auth.schema import AuthUser +from cpl.auth.schema import User TRequest = Union[Request, WebSocket] TEndpoint = Callable[[TRequest, ...], Awaitable[Response]] | Callable[[TRequest, ...], Response] @@ -18,5 +18,5 @@ PartialMiddleware = Union[ Middleware, Callable[[ASGIApp], ASGIApp], ] -PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] +PolicyResolver = Callable[[User], bool | Awaitable[bool]] PolicyInput = Union[dict[str, PolicyResolver], "Policy"] diff --git a/src/cpl-auth/cpl/auth/auth_module.py b/src/cpl-auth/cpl/auth/auth_module.py index ea2b8582..aa1f7bef 100644 --- a/src/cpl-auth/cpl/auth/auth_module.py +++ b/src/cpl-auth/cpl/auth/auth_module.py @@ -12,7 +12,7 @@ from cpl.dependency.service_provider import ServiceProvider from .keycloak.keycloak_admin import KeycloakAdmin from .keycloak.keycloak_client import KeycloakClient from .schema._administration.api_key_dao import ApiKeyDao -from .schema._administration.auth_user_dao import AuthUserDao +from .schema._administration.user_dao import UserDao from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao from .schema._permission.permission_dao import PermissionDao from .schema._permission.role_dao import RoleDao @@ -26,7 +26,7 @@ class AuthModule(Module): singleton = [ KeycloakClient, KeycloakAdmin, - AuthUserDao, + UserDao, ApiKeyDao, ApiKeyPermissionDao, PermissionDao, diff --git a/src/cpl-auth/cpl/auth/permission/role_seeder.py b/src/cpl-auth/cpl/auth/permission/role_seeder.py index 15925299..b6a2db43 100644 --- a/src/cpl-auth/cpl/auth/permission/role_seeder.py +++ b/src/cpl-auth/cpl/auth/permission/role_seeder.py @@ -6,7 +6,7 @@ from cpl.auth.schema import ( RolePermissionDao, ApiKeyDao, ApiKeyPermissionDao, - AuthUserDao, + UserDao, RoleUserDao, RoleUser, ) @@ -23,7 +23,7 @@ class RoleSeeder(DataSeederABC): role_permission_dao: RolePermissionDao, api_key_dao: ApiKeyDao, api_key_permission_dao: ApiKeyPermissionDao, - user_dao: AuthUserDao, + user_dao: UserDao, role_user_dao: RoleUserDao, ): DataSeederABC.__init__(self) diff --git a/src/cpl-auth/cpl/auth/schema/__init__.py b/src/cpl-auth/cpl/auth/schema/__init__.py index cdb4b9d1..af3373ee 100644 --- a/src/cpl-auth/cpl/auth/schema/__init__.py +++ b/src/cpl-auth/cpl/auth/schema/__init__.py @@ -1,7 +1,7 @@ from ._administration.api_key import ApiKey from ._administration.api_key_dao import ApiKeyDao -from ._administration.auth_user import AuthUser -from ._administration.auth_user_dao import AuthUserDao +from ._administration.user import User +from ._administration.user_dao import UserDao from ._permission.api_key_permission import ApiKeyPermission from ._permission.api_key_permission_dao import ApiKeyPermissionDao diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py b/src/cpl-auth/cpl/auth/schema/_administration/user.py similarity index 78% rename from src/cpl-auth/cpl/auth/schema/_administration/auth_user.py rename to src/cpl-auth/cpl/auth/schema/_administration/user.py index 950a321c..f20740e6 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/user.py @@ -13,7 +13,7 @@ from cpl.database.logger import DBLogger from cpl.dependency import get_provider -class AuthUser(DbModelABC[Self]): +class User(DbModelABC[Self]): def __init__( self, id: SerialId, @@ -69,21 +69,21 @@ class AuthUser(DbModelABC[Self]): @async_property async def permissions(self): - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) - return await auth_user_dao.get_permissions(self.id) + user_dao: UserDao = get_provider().get_service(UserDao) + return await user_dao.get_permissions(self.id) async def has_permission(self, permission: Permissions) -> bool: - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) - return await auth_user_dao.has_permission(self.id, permission) + user_dao: UserDao = get_provider().get_service(UserDao) + return await user_dao.has_permission(self.id, permission) async def anonymize(self): - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) + user_dao: UserDao = get_provider().get_service(UserDao) self._keycloak_id = str(uuid.UUID(int=0)) - await auth_user_dao.update(self) + await user_dao.update(self) diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py b/src/cpl-auth/cpl/auth/schema/_administration/user_dao.py similarity index 83% rename from src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py rename to src/cpl-auth/cpl/auth/schema/_administration/user_dao.py index bf59a534..206ab553 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/user_dao.py @@ -3,21 +3,21 @@ from typing import Optional, Union from cpl.auth.permission.permissions import Permissions from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission import Permission -from cpl.auth.schema._administration.auth_user import AuthUser +from cpl.auth.schema._administration.user import User from cpl.database import TableManager from cpl.database.abc import DbModelDaoABC from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.dependency.context import get_provider -class AuthUserDao(DbModelDaoABC[AuthUser]): +class UserDao(DbModelDaoABC[User]): def __init__(self, permission_dao: PermissionDao): - DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) + DbModelDaoABC.__init__(self, User, TableManager.get("users")) self._permissions = permission_dao - self.attribute(AuthUser.keycloak_id, str) + self.attribute(User.keycloak_id, str) async def get_users(): return [(x.id, x.username, x.email) for x in await self.get_all()] @@ -31,11 +31,11 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): .with_value_getter(get_users) ) - async def get_by_keycloak_id(self, keycloak_id: str) -> AuthUser: - return await self.get_single_by({AuthUser.keycloak_id: keycloak_id}) + async def get_by_keycloak_id(self, keycloak_id: str) -> User: + return await self.get_single_by({User.keycloak_id: keycloak_id}) - async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[AuthUser]: - return await self.find_single_by({AuthUser.keycloak_id: keycloak_id}) + async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[User]: + return await self.find_single_by({User.keycloak_id: keycloak_id}) async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool: from cpl.auth.schema._permission.permission_dao import PermissionDao diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py index 90c4e05c..53806c9c 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py @@ -29,10 +29,10 @@ class RoleUser(DbJoinModelABC): @async_property async def user(self): - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) - return await auth_user_dao.get_by_id(self._user_id) + user_dao: UserDao = get_provider().get_service(UserDao) + return await user_dao.get_by_id(self._user_id) @property def role_id(self) -> int: diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql b/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql index c3e09082..2226a9c2 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS administration_auth_users +CREATE TABLE IF NOT EXISTS administration_users ( id INT AUTO_INCREMENT PRIMARY KEY, keycloakId CHAR(36) NOT NULL, @@ -9,10 +9,10 @@ CREATE TABLE IF NOT EXISTS administration_auth_users updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UC_KeycloakId UNIQUE (keycloakId), - CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_users (id) ); -CREATE TABLE IF NOT EXISTS administration_auth_users_history +CREATE TABLE IF NOT EXISTS administration_users_history ( id INT NOT NULL, keycloakId CHAR(36) NOT NULL, @@ -23,22 +23,22 @@ CREATE TABLE IF NOT EXISTS administration_auth_users_history updated TIMESTAMP NOT NULL ); -CREATE TRIGGER TR_administration_auth_usersUpdate +CREATE TRIGGER TR_administration_usersUpdate AFTER UPDATE - ON administration_auth_users + ON administration_users FOR EACH ROW BEGIN - INSERT INTO administration_auth_users_history + INSERT INTO administration_users_history (id, keycloakId, deleted, editorId, created, updated) VALUES (OLD.id, OLD.keycloakId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; -CREATE TRIGGER TR_administration_auth_usersDelete +CREATE TRIGGER TR_administration_usersDelete AFTER DELETE - ON administration_auth_users + ON administration_users FOR EACH ROW BEGIN - INSERT INTO administration_auth_users_history + INSERT INTO administration_users_history (id, keycloakId, deleted, editorId, created, updated) VALUES (OLD.id, OLD.keycloakId, 1, OLD.editorId, OLD.created, NOW()); END; \ No newline at end of file diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql b/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql index 134c6c78..09418f91 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql @@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys CONSTRAINT UC_Identifier_Key UNIQUE (identifier, keyString), CONSTRAINT UC_Key UNIQUE (keyString), - CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS administration_api_keys_history diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql index 63a58fbf..23b4ecc8 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql @@ -8,7 +8,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UQ_PermissionName UNIQUE (name), - CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_permissions_history @@ -52,7 +52,7 @@ CREATE TABLE IF NOT EXISTS permission_roles created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UQ_RoleName UNIQUE (name), - CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_roles_history @@ -98,7 +98,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId), CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, - CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_role_permissions_history @@ -132,7 +132,7 @@ BEGIN VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); END; -CREATE TABLE IF NOT EXISTS permission_role_auth_users +CREATE TABLE IF NOT EXISTS permission_role_users ( id INT AUTO_INCREMENT PRIMARY KEY, roleId INT NOT NULL, @@ -142,12 +142,12 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId), - CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, - CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (userId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, - CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_Roleusers_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT FK_Roleusers_User FOREIGN KEY (userId) REFERENCES administration_users (id) ON DELETE CASCADE, + CONSTRAINT FK_Roleusers_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); -CREATE TABLE IF NOT EXISTS permission_role_auth_users_history +CREATE TABLE IF NOT EXISTS permission_role_users_history ( id INT NOT NULL, roleId INT NOT NULL, @@ -158,22 +158,22 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users_history updated TIMESTAMP NOT NULL ); -CREATE TRIGGER TR_Roleauth_usersUpdate +CREATE TRIGGER TR_RoleusersUpdate AFTER UPDATE - ON permission_role_auth_users + ON permission_role_users FOR EACH ROW BEGIN - INSERT INTO permission_role_auth_users_history + INSERT INTO permission_role_users_history (id, roleId, userId, deleted, editorId, created, updated) VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; -CREATE TRIGGER TR_Roleauth_usersDelete +CREATE TRIGGER TR_RoleusersDelete AFTER DELETE - ON permission_role_auth_users + ON permission_role_users FOR EACH ROW BEGIN - INSERT INTO permission_role_auth_users_history + INSERT INTO permission_role_users_history (id, roleId, userId, deleted, editorId, created, updated) VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW()); END; diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql b/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql index 8f8253fd..3effa6c0 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql @@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId), CONSTRAINT FK_ApiKeyPermissions_ApiKey FOREIGN KEY (apiKeyId) REFERENCES administration_api_keys (id) ON DELETE CASCADE, CONSTRAINT FK_ApiKeyPermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, - CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql b/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql index 41d15483..1735852a 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql @@ -1,26 +1,26 @@ CREATE SCHEMA IF NOT EXISTS administration; -CREATE TABLE IF NOT EXISTS administration.auth_users +CREATE TABLE IF NOT EXISTS administration.users ( id SERIAL PRIMARY KEY, keycloakId UUID NOT NULL, -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UC_KeycloakId UNIQUE (keycloakId) ); -CREATE TABLE IF NOT EXISTS administration.auth_users_history +CREATE TABLE IF NOT EXISTS administration.users_history ( - LIKE administration.auth_users + LIKE administration.users ); CREATE TRIGGER users_history_trigger BEFORE INSERT OR UPDATE OR DELETE - ON administration.auth_users + ON administration.users FOR EACH ROW EXECUTE FUNCTION public.history_trigger_function(); diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql b/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql index 9944d667..e96ed708 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql @@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS administration.api_keys keyString VARCHAR(255) NOT NULL, -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql index 72400191..8ac5e1b1 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql @@ -9,7 +9,7 @@ CREATE TABLE permission.permissions -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_PermissionName UNIQUE (name) @@ -35,7 +35,7 @@ CREATE TABLE permission.roles -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_RoleName UNIQUE (name) @@ -61,7 +61,7 @@ CREATE TABLE permission.role_permissions -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId) @@ -79,27 +79,27 @@ CREATE TRIGGER versioning_trigger EXECUTE PROCEDURE public.history_trigger_function(); -- Role user -CREATE TABLE permission.role_auth_users +CREATE TABLE permission.role_users ( id SERIAL PRIMARY KEY, RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE, - UserId INT NOT NULL REFERENCES administration.auth_users (id) ON DELETE CASCADE, + UserId INT NOT NULL REFERENCES administration.users (id) ON DELETE CASCADE, -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId) ); -CREATE TABLE permission.role_auth_users_history +CREATE TABLE permission.role_users_history ( - LIKE permission.role_auth_users + LIKE permission.role_users ); CREATE TRIGGER versioning_trigger BEFORE INSERT OR UPDATE OR DELETE - ON permission.role_auth_users + ON permission.role_users FOR EACH ROW EXECUTE PROCEDURE public.history_trigger_function(); \ No newline at end of file diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql b/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql index 18e0d706..e0d677bb 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql @@ -6,7 +6,7 @@ CREATE TABLE permission.api_key_permissions -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId) diff --git a/src/cpl-core/cpl/core/ctx/user_context.py b/src/cpl-core/cpl/core/ctx/user_context.py index a60d69f9..7aaa3584 100644 --- a/src/cpl-core/cpl/core/ctx/user_context.py +++ b/src/cpl-core/cpl/core/ctx/user_context.py @@ -1,13 +1,13 @@ from contextvars import ContextVar from typing import Optional -from cpl.auth.schema._administration.auth_user import AuthUser +from cpl.auth.schema._administration.user import User from cpl.dependency import get_provider -_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) +_user_context: ContextVar[Optional[User]] = ContextVar("user", default=None) -def set_user(user: Optional[AuthUser]): +def set_user(user: Optional[User]): from cpl.core.log.logger_abc import LoggerABC logger = get_provider().get_service(LoggerABC) @@ -15,5 +15,5 @@ def set_user(user: Optional[AuthUser]): _user_context.set(user) -def get_user() -> Optional[AuthUser]: +def get_user() -> Optional[User]: return _user_context.get() diff --git a/src/cpl-database/cpl/database/abc/db_model_abc.py b/src/cpl-database/cpl/database/abc/db_model_abc.py index 4f38a8de..3272bf67 100644 --- a/src/cpl-database/cpl/database/abc/db_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_model_abc.py @@ -49,11 +49,11 @@ class DbModelABC(ABC, Generic[T]): if self._editor_id is None: return None - from cpl.auth.schema import AuthUserDao + from cpl.auth.schema import UserDao - auth_user_dao = get_provider().get_service(AuthUserDao) + user_dao = get_provider().get_service(UserDao) - return await auth_user_dao.get_by_id(self._editor_id) + return await user_dao.get_by_id(self._editor_id) @property def created(self) -> datetime: diff --git a/src/cpl-database/cpl/database/abc/db_model_dao_abc.py b/src/cpl-database/cpl/database/abc/db_model_dao_abc.py index 9d9bfef6..873ba4fd 100644 --- a/src/cpl-database/cpl/database/abc/db_model_dao_abc.py +++ b/src/cpl-database/cpl/database/abc/db_model_dao_abc.py @@ -18,7 +18,7 @@ class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]): self.attribute(DbModelABC.editor_id, int, db_name="editorId", ignore=True) # handled by db trigger self.reference( - "editor", "id", DbModelABC.editor_id, TableManager.get("auth_users") + "editor", "id", DbModelABC.editor_id, TableManager.get("users") ) # not relevant for updates due to editor_id self.attribute(DbModelABC.created, datetime, ignore=True) # handled by db trigger diff --git a/src/cpl-database/cpl/database/table_manager.py b/src/cpl-database/cpl/database/table_manager.py index 2d5ac533..7ca8d4e9 100644 --- a/src/cpl-database/cpl/database/table_manager.py +++ b/src/cpl-database/cpl/database/table_manager.py @@ -7,9 +7,9 @@ class TableManager: ServerTypes.POSTGRES: "system._executed_migrations", ServerTypes.MYSQL: "system__executed_migrations", }, - "auth_users": { - ServerTypes.POSTGRES: "administration.auth_users", - ServerTypes.MYSQL: "administration_auth_users", + "users": { + ServerTypes.POSTGRES: "administration.users", + ServerTypes.MYSQL: "administration_users", }, "api_keys": { ServerTypes.POSTGRES: "administration.api_keys", @@ -32,8 +32,8 @@ class TableManager: ServerTypes.MYSQL: "permission_role_permissions", }, "role_users": { - ServerTypes.POSTGRES: "permission.role_auth_users", - ServerTypes.MYSQL: "permission_role_auth_users", + ServerTypes.POSTGRES: "permission.role_users", + ServerTypes.MYSQL: "permission_role_users", }, } diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py index ad4b06f0..bb422941 100644 --- a/src/cpl-graphql/cpl/graphql/application/graphql_app.py +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -16,6 +16,9 @@ class GraphQLApp(WebApp): def __init__(self, services: ServiceProvider, modules: Modules): WebApp.__init__(self, services, modules, [GraphQLModule]) + self._with_graphiql = False + self._with_playground = False + def with_graphql( self, authentication: bool = False, @@ -57,6 +60,7 @@ class GraphQLApp(WebApp): policies=policies, match=match, ) + self._with_graphiql = True return self def with_playground( @@ -77,4 +81,13 @@ class GraphQLApp(WebApp): policies=policies, match=match, ) + self._with_playground = True return self + + + async def _log_before_startup(self): + self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") + if self._with_graphiql: + self._logger.warning(f"GraphiQL available at http://{self._api_settings.host}:{self._api_settings.port}/api/graphiql") + if self._with_playground: + self._logger.warning(f"GraphQL Playground available at http://{self._api_settings.host}:{self._api_settings.port}/api/playground") \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py deleted file mode 100644 index d96af34e..00000000 --- a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_graph_type.py +++ /dev/null @@ -1,12 +0,0 @@ -from cpl.auth.schema import AuthUser -from cpl.graphql.schema.db_model_graph_type import DbModelGraphType - - -class AuthUserGraphType(DbModelGraphType): - - def __init__(self): - DbModelGraphType.__init__(self) - - self.string_field(AuthUser.keycloak_id, lambda root: root.keycloak_id) - self.string_field(AuthUser.username, lambda root: root.username) - self.string_field(AuthUser.email, lambda root: root.email) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/__init__.py b/src/cpl-graphql/cpl/graphql/auth/administration/user/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_filter.py similarity index 80% rename from src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py rename to src/cpl-graphql/cpl/graphql/auth/administration/user/user_filter.py index 19264a46..991e6efb 100644 --- a/src/cpl-graphql/cpl/graphql/auth/administration/auth_user_filter.py +++ b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_filter.py @@ -1,9 +1,9 @@ -from cpl.auth.schema import AuthUser +from cpl.auth.schema import User from cpl.graphql.schema.filter.db_model_filter import DbModelFilter from cpl.graphql.schema.filter.string_filter import StringFilter -class AuthUserFilter(DbModelFilter[AuthUser]): +class UserFilter(DbModelFilter[User]): def __init__(self, public: bool = False): DbModelFilter.__init__(self, public) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py new file mode 100644 index 00000000..d27ce05a --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py @@ -0,0 +1,12 @@ +from cpl.auth.schema import User +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class UserGraphType(DbModelGraphType): + + def __init__(self): + DbModelGraphType.__init__(self) + + self.string_field(User.keycloak_id, lambda root: root.keycloak_id) + self.string_field(User.username, lambda root: root.username) + self.string_field(User.email, lambda root: root.email) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py new file mode 100644 index 00000000..be46dd10 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py @@ -0,0 +1,23 @@ +from cpl.auth.schema import User +from cpl.core.typing import SerialId +from cpl.graphql.schema.input import Input + + +class UserCreateInput(Input[User]): + keycloak_id: str + roles: list[SerialId] + + def __init__(self): + Input.__init__(self) + self.string_field("keycloak_id").with_required() + self.list_field("roles", SerialId) + + +class UserUpdateInput(Input[User]): + id: SerialId + roles: list[SerialId] + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.list_field("roles", SerialId) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py new file mode 100644 index 00000000..c33fd76c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py @@ -0,0 +1,112 @@ +from cpl.api import APILogger +from cpl.auth.keycloak import KeycloakAdmin +from cpl.auth.permission import Permissions +from cpl.auth.schema import UserDao, User, RoleUser, RoleUserDao, RoleDao +from cpl.core.ctx.user_context import get_user +from cpl.graphql.auth.administration.user.user_input import UserCreateInput, UserUpdateInput +from cpl.graphql.schema.mutation import Mutation + + +class UserMutation(Mutation): + def __init__( + self, + logger: APILogger, + user_dao: UserDao, + role_user_dao: RoleUserDao, + role_dao: RoleDao, + keycloak_admin: KeycloakAdmin, + ): + Mutation.__init__(self) + self._logger = logger + self._user_dao = user_dao + self._role_user_dao = role_user_dao + self._role_dao = role_dao + self._keycloak_admin = keycloak_admin + + self.int_field( + "create", + self.resolve_create, + ).with_require_any_permission(Permissions.users_create).with_argument( + "input", + UserCreateInput, + ).with_required() + + self.bool_field( + "update", + self.resolve_update, + ).with_require_any_permission(Permissions.users_update).with_argument( + "input", + UserUpdateInput, + ).with_required() + + self.bool_field( + "delete", + self.resolve_delete, + ).with_require_any_permission(Permissions.users_delete).with_argument( + "id", + int, + ).with_required() + + self.bool_field( + "restore", + self.resolve_restore, + ).with_require_any_permission(Permissions.users_delete).with_argument( + "id", + int, + ).with_required() + + async def resolve_create(self, input: UserCreateInput): + self._logger.debug(f"create user: {input.__dict__}") + + # ensure keycloak knows a user with this keycloak_id + # get_user should raise an exception if the user does not exist + kc_user = self._keycloak_admin.get_user(input.keycloak_id) + if kc_user is None: + raise ValueError(f"Keycloak user with id {input.keycloak_id} does not exist") + + user = User(0, input.keycloak_id, input.license) + user_id = await self._user_dao.create(user) + user = await self._user_dao.get_by_id(user_id) + await self._role_user_dao.create_many([RoleUser(0, user.id, x) for x in set(input.roles)]) + + return user + + async def resolve_update(self, input: UserUpdateInput): + self._logger.debug(f"update user: {input.__dict__}") + user = await self._user_dao.get_by_id(input.id) + + if input.license: + user.license = input.license + + await self._user_dao.update(user) + await self._resolve_assignments( + input.roles or [], + user, + RoleUser.user_id, + RoleUser.role_id, + self._user_dao, + self._role_user_dao, + RoleUser, + self._role_dao, + ) + + return user + + async def resolve_delete(self, id: int): + self._logger.debug(f"delete user: {id}") + user = await self._user_dao.get_by_id(id) + await self._user_dao.delete(user) + try: + active_user = get_user() + if active_user is not None and active_user.id == user.id: + # await broadcast.publish("userLogout", user.id) + self._keycloak_admin.user_logout(user_id=user.keycloak_id) + except Exception as e: + self._logger.error(f"Failed to logout user from Keycloak", e) + return True + + async def resolve_restore(self, id: int): + self._logger.debug(f"restore user: {id}") + user = await self._user_dao.get_by_id(id) + await self._user_dao.restore(user) + return True diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index a0724910..871676ff 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -1,7 +1,23 @@ +from cpl.core.configuration import Configuration +from cpl.dependency import ServiceProvider from cpl.dependency.module.module import Module -from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter -from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType +from cpl.dependency.service_collection import ServiceCollection +from cpl.graphql.auth.administration.user.user_filter import UserFilter +from cpl.graphql.auth.administration.user.user_graph_type import UserGraphType +from cpl.graphql.auth.administration.user.user_mutation import UserMutation +from cpl.graphql.graphql_module import GraphQLModule +from cpl.graphql.service.schema import Schema class GraphQLAuthModule(Module): - transient = [AuthUserGraphType, AuthUserFilter] + dependencies = [GraphQLModule] + transient = [UserGraphType, UserMutation, UserFilter] + + @staticmethod + def register(collection: ServiceCollection): + Configuration.set("GraphQLAuthModuleEnabled", True) + + @staticmethod + def configure(provider: ServiceProvider): + schema = provider.get_service(Schema) + schema.with_type(UserGraphType) \ No newline at end of file diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index 05a36787..b749d16e 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -1,8 +1,6 @@ from cpl.api.api_module import ApiModule -from cpl.dependency import ServiceCollection from cpl.dependency.module.module import Module from cpl.dependency.service_provider import ServiceProvider -from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.filter import Filter @@ -20,10 +18,6 @@ class GraphQLModule(Module): scoped = [GraphQLService] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] - @staticmethod - def register(collection: ServiceCollection): - collection.add_module(GraphQLAuthModule) - @staticmethod def configure(services: ServiceProvider) -> None: schema = services.get_service(Schema) diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py index 44a916ee..831273c4 100644 --- a/src/cpl-graphql/cpl/graphql/query_context.py +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -3,7 +3,7 @@ from typing import Optional from graphql import GraphQLResolveInfo -from cpl.auth.schema import AuthUser, Permission +from cpl.auth.schema import User, Permission from cpl.core.ctx import get_user @@ -25,7 +25,7 @@ class QueryContext: self._is_mutation = is_mutation @property - def user(self) -> AuthUser: + def user(self) -> User: return self._user @property diff --git a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py index 32f6cfbc..a2e5ee1f 100644 --- a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py @@ -2,6 +2,7 @@ from typing import Type, Optional, Generic, Annotated import strawberry +from cpl.core.configuration import Configuration from cpl.core.typing import T from cpl.database.abc.data_access_object_abc import DataAccessObjectABC from cpl.graphql.schema.graph_type import GraphType @@ -23,9 +24,9 @@ class DbModelGraphType(GraphType[T], Generic[T]): self.int_field("id", lambda root: root.id).with_public(public) self.bool_field("deleted", lambda root: root.deleted).with_public(public) - from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType - - self.object_field("editor", lambda: AuthUserGraphType, lambda root: root.editor).with_public(public) + if Configuration.get("GraphQLAuthModuleEnabled", False): + from cpl.graphql.auth.administration.user.user_graph_type import UserGraphType + self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public) self.string_field("created", lambda root: root.created).with_public(public) self.string_field("updated", lambda root: root.updated).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py index aa4fb4d8..6e7681a7 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -1,5 +1,6 @@ from typing import Generic +from cpl.core.configuration.configuration import Configuration from cpl.core.typing import T from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter @@ -13,8 +14,9 @@ class DbModelFilter(Filter[T], Generic[T]): self.field("id", IntFilter).with_public(public) self.field("deleted", BoolFilter).with_public(public) - from cpl.graphql.auth.administration.auth_user_filter import AuthUserFilter + if Configuration.get("GraphQLAuthModuleEnabled", False): + from cpl.graphql.auth.administration.user.user_filter import UserFilter + self.field("editor", lambda: UserFilter).with_public(public) - self.field("editor", lambda: AuthUserFilter).with_public(public) self.field("created", DateFilter).with_public(public) self.field("updated", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index 6463ace9..75bd3c3c 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -13,16 +13,16 @@ class Filter(Input[T]): Input.__init__(self) def filter_field(self, name: str, filter_type: Type["Filter"]): - self.field(name, filter_type()) + self.field(name, filter_type) def string_field(self, name: str): - self.field(name, StringFilter()) + self.field(name, StringFilter) def int_field(self, name: str): - self.field(name, IntFilter()) + self.field(name, IntFilter) def bool_field(self, name: str): - self.field(name, BoolFilter()) + self.field(name, BoolFilter) def date_field(self, name: str): - self.field(name, DateFilter()) + self.field(name, DateFilter) diff --git a/src/cpl-graphql/cpl/graphql/schema/mutation.py b/src/cpl-graphql/cpl/graphql/schema/mutation.py index 691cee10..82a707e9 100644 --- a/src/cpl-graphql/cpl/graphql/schema/mutation.py +++ b/src/cpl-graphql/cpl/graphql/schema/mutation.py @@ -1,5 +1,7 @@ -from typing import Type +from typing import Type, Union +from cpl.core.typing import T +from cpl.database.abc import DataAccessObjectABC, DbJoinModelABC from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.query_abc import QueryABC @@ -23,3 +25,75 @@ class Mutation(QueryABC): raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider") return self.field(name, sub.to_strawberry(), lambda: sub) + + @staticmethod + async def _resolve_assignments( + foreign_objs: list[int], + resolved_obj: T, + reference_key_own: Union[str, property], + reference_key_foreign: Union[str, property], + source_dao: DataAccessObjectABC[T], + join_dao: DataAccessObjectABC[T], + join_type: Type[DbJoinModelABC], + foreign_dao: DataAccessObjectABC[T], + ): + if foreign_objs is None: + return + + reference_key_foreign_attr = reference_key_foreign + if isinstance(reference_key_foreign, property): + reference_key_foreign_attr = reference_key_foreign.fget.__name__ + + foreign_list = await join_dao.find_by( + [{reference_key_own: resolved_obj.id}, {"deleted": False}] + ) + + to_delete = ( + foreign_list + if len(foreign_objs) == 0 + else await join_dao.find_by( + [ + {reference_key_own: resolved_obj.id}, + {reference_key_foreign: {"notIn": foreign_objs}}, + ] + ) + ) + foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list] + deleted_foreign_ids = [ + getattr(x, reference_key_foreign_attr) + for x in await join_dao.find_by( + [{reference_key_own: resolved_obj.id}, {"deleted": True}] + ) + ] + + to_create = [ + join_type(0, resolved_obj.id, x) + for x in foreign_objs + if x not in foreign_ids and x not in deleted_foreign_ids + ] + to_restore = [ + await join_dao.get_single_by( + [ + {reference_key_own: resolved_obj.id}, + {reference_key_foreign: x}, + ] + ) + for x in foreign_objs + if x not in foreign_ids and x in deleted_foreign_ids + ] + + if len(to_delete) > 0: + await join_dao.delete_many(to_delete) + + if len(to_create) > 0: + await join_dao.create_many(to_create) + + if len(to_restore) > 0: + await join_dao.restore_many(to_restore) + + foreign_changes = [*to_delete, *to_create, *to_restore] + if len(foreign_changes) > 0: + await source_dao.touch(resolved_obj) + await foreign_dao.touch_many_by_id( + [getattr(x, reference_key_foreign_attr) for x in foreign_changes] + ) diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 3142adaa..9141f455 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -6,7 +6,6 @@ import strawberry from cpl.api.logger import APILogger from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol -from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery @@ -17,9 +16,7 @@ class Schema: self._logger = logger self._provider = provider - self._types: dict[str, Type[StrawberryProtocol]] = { - "AuthUserGraphType": AuthUserGraphType, - } + self._types: dict[str, Type[StrawberryProtocol]] = {} self._schema = None From 5b3872a1fe085b183cca852416645a7821dcd230 Mon Sep 17 00:00:00 2001 From: edraft Date: Mon, 29 Sep 2025 08:46:51 +0200 Subject: [PATCH 15/20] Added possibility to put auth schema to root graphs #181 --- example/api/src/main.py | 3 ++ .../cpl/graphql/application/graphql_app.py | 20 ++++++++- .../cpl/graphql/auth/graphql_auth_module.py | 4 +- .../cpl/graphql/schema/collection.py | 42 +++++++++++-------- .../cpl/graphql/utils/type_collector.py | 8 ++-- 5 files changed, 54 insertions(+), 23 deletions(-) diff --git a/example/api/src/main.py b/example/api/src/main.py index d4fae505..ac906f9b 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -96,6 +96,9 @@ def main(): schema.mutation.with_mutation("post", PostMutation).with_public() + app.with_auth_root_queries(True) + app.with_auth_root_mutations(True) + app.with_playground() app.with_graphiql() diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py index bb422941..13e688c3 100644 --- a/src/cpl-graphql/cpl/graphql/application/graphql_app.py +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -5,11 +5,16 @@ from cpl.api.application import WebApp from cpl.api.model.validation_match import ValidationMatch from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.typing import Modules +from queries.user import UserGraphType, UserFilter, UserSort from .._endpoints.graphiql import graphiql_endpoint from .._endpoints.graphql import graphql_endpoint from .._endpoints.playground import playground_endpoint +from ..auth.administration.user.user_mutation import UserMutation from ..graphql_module import GraphQLModule from ..service.schema import Schema +from ...application.abc.application_abc import __not_implemented__ +from ...auth.schema import UserDao +from ...core.configuration import Configuration class GraphQLApp(WebApp): @@ -84,10 +89,23 @@ class GraphQLApp(WebApp): self._with_playground = True return self + def with_auth_root_queries(self, public: bool = False): + if not Configuration.get("GraphQLAuthModuleEnabled", False): + raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + + schema = self._services.get_service(Schema) + schema.query.dao_collection_field(UserGraphType, UserDao, "users", UserFilter, UserSort).with_public(public) + + def with_auth_root_mutations(self, public: bool = False): + if not Configuration.get("GraphQLAuthModuleEnabled", False): + raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + + schema = self._services.get_service(Schema) + schema.mutation.with_mutation("user", UserMutation).with_public(public) async def _log_before_startup(self): self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") if self._with_graphiql: self._logger.warning(f"GraphiQL available at http://{self._api_settings.host}:{self._api_settings.port}/api/graphiql") if self._with_playground: - self._logger.warning(f"GraphQL Playground available at http://{self._api_settings.host}:{self._api_settings.port}/api/playground") \ No newline at end of file + self._logger.warning(f"GraphQL Playground available at http://{self._api_settings.host}:{self._api_settings.port}/api/playground") diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index 871676ff..e340eaeb 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -20,4 +20,6 @@ class GraphQLAuthModule(Module): @staticmethod def configure(provider: ServiceProvider): schema = provider.get_service(Schema) - schema.with_type(UserGraphType) \ No newline at end of file + schema.with_type(UserGraphType) + + diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 0dbc66c0..9d600ab9 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -7,13 +7,15 @@ from cpl.dependency import get_provider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol -class CollectionGraphTypeFactory: - _cache: Dict[Type, Type] = {} +from cpl.graphql.utils.type_collector import TypeCollector +class CollectionGraphTypeFactory: @classmethod def get(cls, node_type: Type[StrawberryProtocol]) -> Type: - if node_type in cls._cache: - return cls._cache[node_type] + type_name = f"{node_type.__name__.replace('GraphType', '')}Collection" + + if TypeCollector.has(type_name): + return TypeCollector.get(type_name) node_t = get_provider().get_service(node_type) if not node_t: @@ -21,24 +23,30 @@ class CollectionGraphTypeFactory: gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type - gql_type = strawberry.type( - type( - f"{node_type.__name__.replace("GraphType", "")}Collection", - (), - { - "__annotations__": { - "nodes": List[gql_node], - "total_count": int, - "count": int, - } - }, - ) + gql_cls = type( + type_name, + (), + {} ) - cls._cache[node_type] = gql_type + TypeCollector.set(type_name, gql_cls) + + gql_cls.__annotations__ = { + "nodes": List[gql_node], + "total_count": int, + "count": int, + } + for k in gql_cls.__annotations__.keys(): + setattr(gql_cls, k, strawberry.field()) + + gql_type = strawberry.type(gql_cls) + + TypeCollector.set(type_name, gql_type) return gql_type + + class Collection: def __init__(self, nodes: list[T], total_count: int, count: int): self._nodes = nodes diff --git a/src/cpl-graphql/cpl/graphql/utils/type_collector.py b/src/cpl-graphql/cpl/graphql/utils/type_collector.py index bf9e4332..439d3ec2 100644 --- a/src/cpl-graphql/cpl/graphql/utils/type_collector.py +++ b/src/cpl-graphql/cpl/graphql/utils/type_collector.py @@ -2,16 +2,16 @@ from typing import Type, Any class TypeCollector: - _registry: dict[type, Type] = {} + _registry: dict[type | str, Type] = {} @classmethod - def has(cls, base: type) -> bool: + def has(cls, base: type | str) -> bool: return base in cls._registry @classmethod - def get(cls, base: type) -> Type: + def get(cls, base: type | str) -> Type: return cls._registry[base] @classmethod - def set(cls, base: type, gql_type: Type): + def set(cls, base: type | str, gql_type: Type): cls._registry[base] = gql_type From 262e26cb8312fe43478fc47e753982920f931d89 Mon Sep 17 00:00:00 2001 From: edraft Date: Mon, 29 Sep 2025 19:51:59 +0200 Subject: [PATCH 16/20] Internal api key gql #181 --- .../cpl/graphql/application/graphql_app.py | 31 +++--- .../auth/administration/api_key/__init__.py | 0 .../administration/api_key/api_key_filter.py | 10 ++ .../api_key/api_key_graph_type.py | 14 +++ .../administration/api_key/api_key_input.py | 25 +++++ .../api_key/api_key_mutation.py | 96 +++++++++++++++++++ .../auth/administration/user/user_input.py | 4 +- .../cpl/graphql/auth/graphql_auth_module.py | 7 +- 8 files changed, 170 insertions(+), 17 deletions(-) create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/api_key/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_input.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py index 13e688c3..207c1861 100644 --- a/src/cpl-graphql/cpl/graphql/application/graphql_app.py +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -1,20 +1,21 @@ +import socket from enum import Enum from typing import Self from cpl.api.application import WebApp from cpl.api.model.validation_match import ValidationMatch +from cpl.auth.schema import UserDao +from cpl.core.configuration import Configuration +from cpl.core.environment import Environment from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.typing import Modules from queries.user import UserGraphType, UserFilter, UserSort -from .._endpoints.graphiql import graphiql_endpoint -from .._endpoints.graphql import graphql_endpoint -from .._endpoints.playground import playground_endpoint -from ..auth.administration.user.user_mutation import UserMutation -from ..graphql_module import GraphQLModule -from ..service.schema import Schema -from ...application.abc.application_abc import __not_implemented__ -from ...auth.schema import UserDao -from ...core.configuration import Configuration +from cpl.graphql._endpoints.graphiql import graphiql_endpoint +from cpl.graphql._endpoints.graphql import graphql_endpoint +from cpl.graphql._endpoints.playground import playground_endpoint +from cpl.graphql.auth.administration.user.user_mutation import UserMutation +from cpl.graphql.graphql_module import GraphQLModule +from cpl.graphql.service.schema import Schema class GraphQLApp(WebApp): @@ -104,8 +105,14 @@ class GraphQLApp(WebApp): schema.mutation.with_mutation("user", UserMutation).with_public(public) async def _log_before_startup(self): - self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") + host = self._api_settings.host + if host == "0.0.0.0" and Environment.get_environment() == "development": + host = "localhost" + elif host == "0.0.0.0": + host = socket.gethostbyname(socket.gethostname()) + + self._logger.info(f"Start API on {host}:{self._api_settings.port}") if self._with_graphiql: - self._logger.warning(f"GraphiQL available at http://{self._api_settings.host}:{self._api_settings.port}/api/graphiql") + self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql") if self._with_playground: - self._logger.warning(f"GraphQL Playground available at http://{self._api_settings.host}:{self._api_settings.port}/api/playground") + self._logger.warning(f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground") diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/__init__.py b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_filter.py b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_filter.py new file mode 100644 index 00000000..9c5752d2 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_filter.py @@ -0,0 +1,10 @@ +from cpl.auth.schema import ApiKey +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class ApiKeyFilter(DbModelFilter[ApiKey]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("identifier", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py new file mode 100644 index 00000000..c70959a1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py @@ -0,0 +1,14 @@ +from cpl.auth.schema import ApiKey, RolePermissionDao +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class ApiKeyGraphType(DbModelGraphType): + + def __init__(self, role_permission_dao: RolePermissionDao): + DbModelGraphType.__init__(self) + + self.string_field(ApiKey.identifier, lambda root: root.identifier) + self.string_field(ApiKey.key, lambda root: root.key) + self.string_field(ApiKey.permissions, lambda root: root.permissions) + + self.set_history_reference_dao(role_permission_dao, "apikeyid") diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_input.py b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_input.py new file mode 100644 index 00000000..a669fce1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_input.py @@ -0,0 +1,25 @@ +from cpl.auth.schema import ApiKey +from cpl.core.typing import SerialId +from cpl.graphql.schema.input import Input + + +class ApiKeyCreateInput(Input[ApiKey]): + identifier: str + permissions: list[SerialId] + + def __init__(self): + Input.__init__(self) + self.string_field("identifier").with_required() + self.list_field("permissions", SerialId) + + +class ApiKeyUpdateInput(Input[ApiKey]): + id: SerialId + identifier: str | None + permissions: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("identifier").with_required() + self.list_field("permissions", SerialId) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py new file mode 100644 index 00000000..ea2f9cf1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py @@ -0,0 +1,96 @@ +from cpl.api import APILogger +from cpl.auth.keycloak import KeycloakAdmin +from cpl.auth.permission import Permissions +from cpl.auth.schema import ApiKey, ApiKeyDao, ApiKeyPermissionDao, ApiKeyPermission +from cpl.graphql.auth.administration.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput +from cpl.graphql.schema.mutation import Mutation + + +class ApiKeyMutation(Mutation): + def __init__( + self, + + logger: APILogger, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + permission_dao: ApiKeyPermissionDao, + keycloak_admin: KeycloakAdmin, + ): + Mutation.__init__(self) + self._logger = logger + self._api_key_dao = api_key_dao + self._api_key_permission_dao = api_key_permission_dao + self._permission_dao = permission_dao + self._keycloak_admin = keycloak_admin + + self.int_field( + "create", + self.resolve_create, + ).with_require_any_permission(Permissions.users_create).with_argument( + "input", + ApiKeyCreateInput, + ).with_required() + + self.bool_field( + "update", + self.resolve_update, + ).with_require_any_permission(Permissions.users_update).with_argument( + "input", + ApiKeyUpdateInput, + ).with_required() + + self.bool_field( + "delete", + self.resolve_delete, + ).with_require_any_permission(Permissions.users_delete).with_argument( + "id", + int, + ).with_required() + + self.bool_field( + "restore", + self.resolve_restore, + ).with_require_any_permission(Permissions.users_delete).with_argument( + "id", + int, + ).with_required() + + async def resolve_create(self, obj: ApiKeyCreateInput): + self._logger.debug(f"create api key: {obj.__dict__}") + + api_key = ApiKey.new(obj.identifier) + await self._api_key_dao.create(api_key) + api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}]) + await self._api_key_permission_dao.create_many( + [ApiKeyPermission(0, api_key.id, x) for x in obj.permissions] + ) + return api_key + + async def resolve_update(self, input: ApiKeyUpdateInput): + self._logger.debug(f"update api key: {input}") + api_key = await self._api_key_dao.get_by_id(input.id) + + await self._resolve_assignments( + input.permissions or [], + api_key, + ApiKeyPermission.api_key_id, + ApiKeyPermission.permission_id, + self._api_key_dao, + self._api_key_permission_dao, + ApiKeyPermission, + self._permission_dao, + ) + + return api_key + + async def resolve_delete(self, id: str): + self._logger.debug(f"delete api key: {id}") + api_key = await self._api_key_dao.get_by_id(id) + await self._api_key_dao.delete(api_key) + return True + + async def resolve_restore(self, id: str): + self._logger.debug(f"restore api key: {id}") + api_key = await self._api_key_dao.get_by_id(id) + await self._api_key_dao.restore(api_key) + return True diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py index be46dd10..c5f5ac07 100644 --- a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py +++ b/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py @@ -5,7 +5,7 @@ from cpl.graphql.schema.input import Input class UserCreateInput(Input[User]): keycloak_id: str - roles: list[SerialId] + roles: list[SerialId] | None def __init__(self): Input.__init__(self) @@ -15,7 +15,7 @@ class UserCreateInput(Input[User]): class UserUpdateInput(Input[User]): id: SerialId - roles: list[SerialId] + roles: list[SerialId] | None def __init__(self): Input.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index e340eaeb..ba0da432 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -2,6 +2,9 @@ from cpl.core.configuration import Configuration from cpl.dependency import ServiceProvider from cpl.dependency.module.module import Module from cpl.dependency.service_collection import ServiceCollection +from cpl.graphql.auth.administration.api_key.api_key_filter import ApiKeyFilter +from cpl.graphql.auth.administration.api_key.api_key_graph_type import ApiKeyGraphType +from cpl.graphql.auth.administration.api_key.api_key_mutation import ApiKeyMutation from cpl.graphql.auth.administration.user.user_filter import UserFilter from cpl.graphql.auth.administration.user.user_graph_type import UserGraphType from cpl.graphql.auth.administration.user.user_mutation import UserMutation @@ -11,7 +14,7 @@ from cpl.graphql.service.schema import Schema class GraphQLAuthModule(Module): dependencies = [GraphQLModule] - transient = [UserGraphType, UserMutation, UserFilter] + transient = [UserGraphType, UserMutation, UserFilter, ApiKeyGraphType, ApiKeyMutation, ApiKeyFilter] @staticmethod def register(collection: ServiceCollection): @@ -21,5 +24,3 @@ class GraphQLAuthModule(Module): def configure(provider: ServiceProvider): schema = provider.get_service(Schema) schema.with_type(UserGraphType) - - From e362b7fb61fb5f6b30cba3f617bb6689bc9df93a Mon Sep 17 00:00:00 2001 From: edraft Date: Mon, 29 Sep 2025 21:00:24 +0200 Subject: [PATCH 17/20] Added/Fixed api_key/user/role gql #181 --- .../schema/_permission/role_permission.py | 20 ++-- src/cpl-graphql/cpl/graphql/abc/query_abc.py | 39 +++++-- .../cpl/graphql/application/graphql_app.py | 23 ++-- .../{administration => api_key}/__init__.py | 0 .../api_key/api_key_filter.py | 0 .../api_key/api_key_graph_type.py | 2 +- .../api_key/api_key_input.py | 0 .../api_key/api_key_mutation.py | 2 +- .../cpl/graphql/auth/api_key/api_key_sort.py | 9 ++ .../cpl/graphql/auth/graphql_auth_module.py | 58 ++++++++-- .../api_key => role}/__init__.py | 0 .../cpl/graphql/auth/role/role_filter.py | 11 ++ .../cpl/graphql/auth/role/role_graph_type.py | 14 +++ .../cpl/graphql/auth/role/role_input.py | 29 +++++ .../cpl/graphql/auth/role/role_mutation.py | 101 ++++++++++++++++++ .../cpl/graphql/auth/role/role_sort.py | 10 ++ .../{administration => }/user/__init__.py | 0 .../{administration => }/user/user_filter.py | 0 .../user/user_graph_type.py | 2 +- .../{administration => }/user/user_input.py | 0 .../user/user_mutation.py | 2 +- .../cpl/graphql/auth/user/user_sort.py | 10 ++ .../cpl/graphql/schema/db_model_graph_type.py | 2 +- .../graphql/schema/filter/db_model_filter.py | 2 +- .../cpl/graphql/schema/sort/db_model_sort.py | 19 ++++ 25 files changed, 311 insertions(+), 44 deletions(-) rename src/cpl-graphql/cpl/graphql/auth/{administration => api_key}/__init__.py (100%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/api_key/api_key_filter.py (100%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/api_key/api_key_graph_type.py (91%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/api_key/api_key_input.py (100%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/api_key/api_key_mutation.py (96%) create mode 100644 src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py rename src/cpl-graphql/cpl/graphql/auth/{administration/api_key => role}/__init__.py (100%) create mode 100644 src/cpl-graphql/cpl/graphql/auth/role/role_filter.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/role/role_input.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py create mode 100644 src/cpl-graphql/cpl/graphql/auth/role/role_sort.py rename src/cpl-graphql/cpl/graphql/auth/{administration => }/user/__init__.py (100%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/user/user_filter.py (100%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/user/user_graph_type.py (89%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/user/user_input.py (100%) rename src/cpl-graphql/cpl/graphql/auth/{administration => }/user/user_mutation.py (97%) create mode 100644 src/cpl-graphql/cpl/graphql/auth/user/user_sort.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py index 8038227b..6aea5fbf 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py @@ -1,14 +1,14 @@ from datetime import datetime -from typing import Optional, Self +from typing import Self from async_property import async_property from cpl.core.typing import SerialId -from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider, get_provider +from cpl.database.abc import DbJoinModelABC +from cpl.dependency import get_provider -class RolePermission(DbModelABC[Self]): +class RolePermission(DbJoinModelABC[Self]): def __init__( self, id: SerialId, @@ -19,28 +19,26 @@ class RolePermission(DbModelABC[Self]): created: datetime | None = None, updated: datetime | None = None, ): - DbModelABC.__init__(self, id, deleted, editor_id, created, updated) - self._role_id = role_id - self._permission_id = permission_id + DbJoinModelABC.__init__(self, id, role_id, permission_id, deleted, editor_id, created, updated) @property def role_id(self) -> int: - return self._role_id + return self._source_id @async_property async def role(self): from cpl.auth.schema._permission.role_dao import RoleDao role_dao: RoleDao = get_provider().get_service(RoleDao) - return await role_dao.get_by_id(self._role_id) + return await role_dao.get_by_id(self._source_id) @property def permission_id(self) -> int: - return self._permission_id + return self._foreign_id @async_property async def permission(self): from cpl.auth.schema._permission.permission_dao import PermissionDao permission_dao: PermissionDao = get_provider().get_service(PermissionDao) - return await permission_dao.get_by_id(self._permission_id) + return await permission_dao.get_by_id(self._foreign_id) diff --git a/src/cpl-graphql/cpl/graphql/abc/query_abc.py b/src/cpl-graphql/cpl/graphql/abc/query_abc.py index 8cad66d2..1c7cb648 100644 --- a/src/cpl-graphql/cpl/graphql/abc/query_abc.py +++ b/src/cpl-graphql/cpl/graphql/abc/query_abc.py @@ -2,10 +2,11 @@ import functools import inspect import types from abc import ABC -from asyncio import iscoroutinefunction, iscoroutine +from asyncio import iscoroutinefunction from typing import Callable, Type, Any, Optional import strawberry +from async_property.base import AsyncPropertyDescriptor from strawberry.exceptions import StrawberryException from cpl.api import Unauthorized, Forbidden @@ -169,6 +170,15 @@ class QueryABC(StrawberryProtocol, ABC): except StrawberryException as e: raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e + @staticmethod + def _type_to_strawberry(t: Type) -> Type: + _t = get_provider().get_service(t) + + if isinstance(_t, StrawberryProtocol): + return _t.to_strawberry() + + return t + def to_strawberry(self) -> Type: cls = self.__class__ if TypeCollector.has(cls): @@ -183,22 +193,35 @@ class QueryABC(StrawberryProtocol, ABC): for name, f in self._fields.items(): t = f.type + if isinstance(name, property): + name = name.fget.__name__ + if isinstance(name, AsyncPropertyDescriptor): + name = name.field_name + + if isinstance(t, types.GenericAlias): + t = t.__args__[0] if callable(t) and not isinstance(t, type): - _t = get_provider().get_service(t()) - if isinstance(_t, StrawberryProtocol): - t = _t.to_strawberry() - else: - t = _t + t = self._type_to_strawberry(t()) + elif issubclass(t, StrawberryProtocol): + t = self._type_to_strawberry(t) annotations[name] = t if not f.optional else Optional[t] namespace[name] = self._field_to_strawberry(f) namespace["__annotations__"] = annotations for k, v in namespace.items(): + if isinstance(k, property): + k = k.fget.__name__ + if isinstance(k, AsyncPropertyDescriptor): + k = k.field_name + setattr(gql_cls, k, v) - gql_cls.__annotations__ = annotations - gql_type = strawberry.type(gql_cls) + try: + gql_cls.__annotations__ = annotations + gql_type = strawberry.type(gql_cls) + except Exception as e: + raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e TypeCollector.set(cls, gql_type) return gql_type diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py index 207c1861..8c15dec8 100644 --- a/src/cpl-graphql/cpl/graphql/application/graphql_app.py +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -4,16 +4,13 @@ from typing import Self from cpl.api.application import WebApp from cpl.api.model.validation_match import ValidationMatch -from cpl.auth.schema import UserDao -from cpl.core.configuration import Configuration +from cpl.application.abc.application_abc import __not_implemented__ from cpl.core.environment import Environment from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.typing import Modules -from queries.user import UserGraphType, UserFilter, UserSort from cpl.graphql._endpoints.graphiql import graphiql_endpoint from cpl.graphql._endpoints.graphql import graphql_endpoint from cpl.graphql._endpoints.playground import playground_endpoint -from cpl.graphql.auth.administration.user.user_mutation import UserMutation from cpl.graphql.graphql_module import GraphQLModule from cpl.graphql.service.schema import Schema @@ -91,18 +88,20 @@ class GraphQLApp(WebApp): return self def with_auth_root_queries(self, public: bool = False): - if not Configuration.get("GraphQLAuthModuleEnabled", False): - raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + try: + from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule - schema = self._services.get_service(Schema) - schema.query.dao_collection_field(UserGraphType, UserDao, "users", UserFilter, UserSort).with_public(public) + GraphQLAuthModule.with_auth_root_queries(self._services, public=public) + except ImportError: + __not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations) def with_auth_root_mutations(self, public: bool = False): - if not Configuration.get("GraphQLAuthModuleEnabled", False): - raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + try: + from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule - schema = self._services.get_service(Schema) - schema.mutation.with_mutation("user", UserMutation).with_public(public) + GraphQLAuthModule.with_auth_root_mutations(self._services, public=public) + except ImportError: + __not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations) async def _log_before_startup(self): host = self._api_settings.host diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/__init__.py b/src/cpl-graphql/cpl/graphql/auth/api_key/__init__.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/__init__.py rename to src/cpl-graphql/cpl/graphql/auth/api_key/__init__.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_filter.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_filter.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_filter.py rename to src/cpl-graphql/cpl/graphql/auth/api_key/api_key_filter.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_graph_type.py similarity index 91% rename from src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py rename to src/cpl-graphql/cpl/graphql/auth/api_key/api_key_graph_type.py index c70959a1..0bb52bbb 100644 --- a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_graph_type.py @@ -2,7 +2,7 @@ from cpl.auth.schema import ApiKey, RolePermissionDao from cpl.graphql.schema.db_model_graph_type import DbModelGraphType -class ApiKeyGraphType(DbModelGraphType): +class ApiKeyGraphType(DbModelGraphType[ApiKey]): def __init__(self, role_permission_dao: RolePermissionDao): DbModelGraphType.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_input.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_input.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_input.py rename to src/cpl-graphql/cpl/graphql/auth/api_key/api_key_input.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py similarity index 96% rename from src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py rename to src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py index ea2f9cf1..c431eee8 100644 --- a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/api_key_mutation.py +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py @@ -2,7 +2,7 @@ from cpl.api import APILogger from cpl.auth.keycloak import KeycloakAdmin from cpl.auth.permission import Permissions from cpl.auth.schema import ApiKey, ApiKeyDao, ApiKeyPermissionDao, ApiKeyPermission -from cpl.graphql.auth.administration.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput +from cpl.graphql.auth.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput from cpl.graphql.schema.mutation import Mutation diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py new file mode 100644 index 00000000..af3d0c18 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py @@ -0,0 +1,9 @@ +from cpl.auth.schema import ApiKey +from cpl.graphql.schema.sort.db_model_sort import DbModelSort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class ApiKeySort(DbModelSort[ApiKey]): + def __init__(self): + DbModelSort.__init__(self) + self.field("identifier", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index ba0da432..9b41cc8e 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -1,20 +1,40 @@ +from cpl.auth.schema import UserDao, ApiKeyDao, RoleDao from cpl.core.configuration import Configuration from cpl.dependency import ServiceProvider from cpl.dependency.module.module import Module from cpl.dependency.service_collection import ServiceCollection -from cpl.graphql.auth.administration.api_key.api_key_filter import ApiKeyFilter -from cpl.graphql.auth.administration.api_key.api_key_graph_type import ApiKeyGraphType -from cpl.graphql.auth.administration.api_key.api_key_mutation import ApiKeyMutation -from cpl.graphql.auth.administration.user.user_filter import UserFilter -from cpl.graphql.auth.administration.user.user_graph_type import UserGraphType -from cpl.graphql.auth.administration.user.user_mutation import UserMutation +from cpl.graphql.auth.api_key.api_key_filter import ApiKeyFilter +from cpl.graphql.auth.api_key.api_key_graph_type import ApiKeyGraphType +from cpl.graphql.auth.api_key.api_key_mutation import ApiKeyMutation +from cpl.graphql.auth.api_key.api_key_sort import ApiKeySort +from cpl.graphql.auth.role.role_filter import RoleFilter +from cpl.graphql.auth.role.role_graph_type import RoleGraphType +from cpl.graphql.auth.role.role_mutation import RoleMutation +from cpl.graphql.auth.role.role_sort import RoleSort +from cpl.graphql.auth.user.user_filter import UserFilter +from cpl.graphql.auth.user.user_graph_type import UserGraphType +from cpl.graphql.auth.user.user_mutation import UserMutation +from cpl.graphql.auth.user.user_sort import UserSort from cpl.graphql.graphql_module import GraphQLModule from cpl.graphql.service.schema import Schema class GraphQLAuthModule(Module): dependencies = [GraphQLModule] - transient = [UserGraphType, UserMutation, UserFilter, ApiKeyGraphType, ApiKeyMutation, ApiKeyFilter] + transient = [ + UserGraphType, + UserMutation, + UserFilter, + UserSort, + ApiKeyGraphType, + ApiKeyMutation, + ApiKeyFilter, + ApiKeySort, + RoleGraphType, + RoleMutation, + RoleFilter, + RoleSort, + ] @staticmethod def register(collection: ServiceCollection): @@ -24,3 +44,27 @@ class GraphQLAuthModule(Module): def configure(provider: ServiceProvider): schema = provider.get_service(Schema) schema.with_type(UserGraphType) + schema.with_type(ApiKeyGraphType) + schema.with_type(RoleGraphType) + + @staticmethod + def with_auth_root_queries(provider: ServiceProvider, public: bool = False): + if not Configuration.get("GraphQLAuthModuleEnabled", False): + raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + + schema = provider.get_service(Schema) + schema.query.dao_collection_field(UserGraphType, UserDao, "users", UserFilter, UserSort).with_public(public) + schema.query.dao_collection_field(ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort).with_public( + public + ) + schema.query.dao_collection_field(RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort).with_public(public) + + @staticmethod + def with_auth_root_mutations(provider: ServiceProvider, public: bool = False): + if not Configuration.get("GraphQLAuthModuleEnabled", False): + raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + + schema = provider.get_service(Schema) + schema.mutation.with_mutation("user", UserMutation).with_public(public) + schema.mutation.with_mutation("apiKey", ApiKeyMutation).with_public(public) + schema.mutation.with_mutation("role", RoleMutation).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/api_key/__init__.py b/src/cpl-graphql/cpl/graphql/auth/role/__init__.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/api_key/__init__.py rename to src/cpl-graphql/cpl/graphql/auth/role/__init__.py diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_filter.py b/src/cpl-graphql/cpl/graphql/auth/role/role_filter.py new file mode 100644 index 00000000..f31dbf4f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_filter.py @@ -0,0 +1,11 @@ +from cpl.auth.schema import User, Role +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class RoleFilter(DbModelFilter[Role]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("name", StringFilter).with_public(public) + self.field("description", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py new file mode 100644 index 00000000..27ce9309 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py @@ -0,0 +1,14 @@ +from cpl.auth.schema import Role +from cpl.graphql.auth.user.user_graph_type import UserGraphType +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class RoleGraphType(DbModelGraphType[Role]): + + def __init__(self, public: bool = False): + DbModelGraphType.__init__(self) + + self.string_field("name", lambda root: root.name).with_public(public) + self.string_field("description", lambda root: root.description).with_public(public) + self.list_field("permissions", str, lambda root: root.permissions).with_public(public) + self.list_field("users", UserGraphType, lambda root: root.users).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_input.py b/src/cpl-graphql/cpl/graphql/auth/role/role_input.py new file mode 100644 index 00000000..7ae1334f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_input.py @@ -0,0 +1,29 @@ +from cpl.auth.schema import User, Role +from cpl.core.typing import SerialId +from cpl.graphql.schema.input import Input + + +class RoleCreateInput(Input[Role]): + name: str + description: str | None + permissions: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.string_field("name").with_required() + self.string_field("description") + self.list_field("permissions", SerialId) + + +class RoleUpdateInput(Input[Role]): + id: SerialId + name: str | None + description: str | None + permissions: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("name") + self.string_field("description") + self.list_field("permissions", SerialId) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py b/src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py new file mode 100644 index 00000000..df7d06d8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py @@ -0,0 +1,101 @@ +from cpl.api import APILogger +from cpl.auth.keycloak import KeycloakAdmin +from cpl.auth.permission import Permissions +from cpl.auth.schema import RoleDao, Role, RolePermissionDao, RolePermission +from cpl.graphql.auth.role.role_input import RoleCreateInput, RoleUpdateInput +from cpl.graphql.schema.mutation import Mutation + + +class RoleMutation(Mutation): + def __init__( + self, + logger: APILogger, + role_dao: RoleDao, + role_permission_dao: RolePermissionDao, + permission_dao: RolePermissionDao, + keycloak_admin: KeycloakAdmin, + ): + Mutation.__init__(self) + self._logger = logger + self._role_dao = role_dao + self._role_permission_dao = role_permission_dao + self._permission_dao = permission_dao + self._keycloak_admin = keycloak_admin + + self.int_field( + "create", + self.resolve_create, + ).with_require_any_permission(Permissions.roles_create).with_argument( + "input", + RoleCreateInput, + ).with_required() + + self.bool_field( + "update", + self.resolve_update, + ).with_require_any_permission(Permissions.roles_update).with_argument( + "input", + RoleUpdateInput, + ).with_required() + + self.bool_field( + "delete", + self.resolve_delete, + ).with_require_any_permission(Permissions.roles_delete).with_argument( + "id", + int, + ).with_required() + + self.bool_field( + "restore", + self.resolve_restore, + ).with_require_any_permission(Permissions.roles_delete).with_argument( + "id", + int, + ).with_required() + + async def resolve_create(self, input: RoleCreateInput, *_): + self._logger.debug(f"create role: {input.__dict__}") + + role = Role( + 0, + input.name, + input.description, + ) + await self._role_dao.create(role) + role = await self._role_dao.get_by_name(role.name) + await self._role_permission_dao.create_many([RolePermission(0, role.id, x) for x in input.permissions]) + + return role + + async def resolve_update(self, input: RoleUpdateInput, *_): + self._logger.debug(f"update role: {input.__dict__}") + role = await self._role_dao.get_by_id(input.id) + role.name = input.get("name", role.name) + role.description = input.get("description", role.description) + await self._role_dao.update(role) + + await self._resolve_assignments( + input.get("permissions", []), + role, + RolePermission.role_id, + RolePermission.permission_id, + self._role_dao, + self._role_permission_dao, + RolePermission, + self._permission_dao, + ) + + return role + + async def resolve_delete(self, id: int): + self._logger.debug(f"delete role: {id}") + role = await self._role_dao.get_by_id(id) + await self._role_dao.delete(role) + return True + + async def resolve_restore(self, id: int): + self._logger.debug(f"restore role: {id}") + role = await self._role_dao.get_by_id(id) + await self._role_dao.restore(role) + return True diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_sort.py b/src/cpl-graphql/cpl/graphql/auth/role/role_sort.py new file mode 100644 index 00000000..6c55568e --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_sort.py @@ -0,0 +1,10 @@ +from cpl.auth.schema import Role +from cpl.graphql.schema.sort.db_model_sort import DbModelSort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class RoleSort(DbModelSort[Role]): + def __init__(self): + DbModelSort.__init__(self) + self.field("name", SortOrder) + self.field("description", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/__init__.py b/src/cpl-graphql/cpl/graphql/auth/user/__init__.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/user/__init__.py rename to src/cpl-graphql/cpl/graphql/auth/user/__init__.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_filter.py b/src/cpl-graphql/cpl/graphql/auth/user/user_filter.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/user/user_filter.py rename to src/cpl-graphql/cpl/graphql/auth/user/user_filter.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py similarity index 89% rename from src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py rename to src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py index d27ce05a..73a44c37 100644 --- a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py @@ -2,7 +2,7 @@ from cpl.auth.schema import User from cpl.graphql.schema.db_model_graph_type import DbModelGraphType -class UserGraphType(DbModelGraphType): +class UserGraphType(DbModelGraphType[User]): def __init__(self): DbModelGraphType.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py b/src/cpl-graphql/cpl/graphql/auth/user/user_input.py similarity index 100% rename from src/cpl-graphql/cpl/graphql/auth/administration/user/user_input.py rename to src/cpl-graphql/cpl/graphql/auth/user/user_input.py diff --git a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py b/src/cpl-graphql/cpl/graphql/auth/user/user_mutation.py similarity index 97% rename from src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py rename to src/cpl-graphql/cpl/graphql/auth/user/user_mutation.py index c33fd76c..59afb752 100644 --- a/src/cpl-graphql/cpl/graphql/auth/administration/user/user_mutation.py +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_mutation.py @@ -3,7 +3,7 @@ from cpl.auth.keycloak import KeycloakAdmin from cpl.auth.permission import Permissions from cpl.auth.schema import UserDao, User, RoleUser, RoleUserDao, RoleDao from cpl.core.ctx.user_context import get_user -from cpl.graphql.auth.administration.user.user_input import UserCreateInput, UserUpdateInput +from cpl.graphql.auth.user.user_input import UserCreateInput, UserUpdateInput from cpl.graphql.schema.mutation import Mutation diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_sort.py b/src/cpl-graphql/cpl/graphql/auth/user/user_sort.py new file mode 100644 index 00000000..fe0cb8b1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_sort.py @@ -0,0 +1,10 @@ +from cpl.auth.schema import User +from cpl.graphql.schema.sort.db_model_sort import DbModelSort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class UserSort(DbModelSort[User]): + def __init__(self): + DbModelSort.__init__(self) + self.field("username", SortOrder) + self.field("email", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py index a2e5ee1f..2b9a39bb 100644 --- a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py @@ -25,7 +25,7 @@ class DbModelGraphType(GraphType[T], Generic[T]): self.bool_field("deleted", lambda root: root.deleted).with_public(public) if Configuration.get("GraphQLAuthModuleEnabled", False): - from cpl.graphql.auth.administration.user.user_graph_type import UserGraphType + from cpl.graphql.auth.user.user_graph_type import UserGraphType self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public) self.string_field("created", lambda root: root.created).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py index 6e7681a7..a7a22cb7 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -15,7 +15,7 @@ class DbModelFilter(Filter[T], Generic[T]): self.field("id", IntFilter).with_public(public) self.field("deleted", BoolFilter).with_public(public) if Configuration.get("GraphQLAuthModuleEnabled", False): - from cpl.graphql.auth.administration.user.user_filter import UserFilter + from cpl.graphql.auth.user.user_filter import UserFilter self.field("editor", lambda: UserFilter).with_public(public) self.field("created", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py b/src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py new file mode 100644 index 00000000..02726ec8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py @@ -0,0 +1,19 @@ +from typing import Generic + +from cpl.core.configuration import Configuration +from cpl.core.typing import T +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class DbModelSort(Sort[T], Generic[T]): + def __init__( + self, + ): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("deleted", SortOrder) + if Configuration.get("GraphQLAuthModuleEnabled", False): + self.field("editor", SortOrder) + self.field("created", SortOrder) + self.field("updated", SortOrder) From cdb5e4ff894e9f08c85d8dc017c858c91dae6d81 Mon Sep 17 00:00:00 2001 From: edraft Date: Sat, 4 Oct 2025 06:57:14 +0200 Subject: [PATCH 18/20] [WIP] Subscriptions --- example/api/src/main.py | 14 ++- example/api/src/model/post_query.py | 29 ++++- example/api/src/queries/hello.py | 5 +- src/cpl-api/cpl/api/middleware/request.py | 4 +- .../cpl/database/mysql/mysql_pool.py | 9 +- .../cpl/dependency/event_bus.py | 10 ++ .../cpl/graphql/abc/strawberry_protocol.py | 4 +- .../cpl/graphql/application/graphql_app.py | 4 +- .../graphql/auth/api_key/api_key_mutation.py | 17 ++- .../cpl/graphql/event_bus/__init__.py | 0 .../cpl/graphql/event_bus/memory.py | 27 +++++ src/cpl-graphql/cpl/graphql/graphql_module.py | 3 +- .../cpl/graphql/schema/collection.py | 9 +- .../cpl/graphql/schema/db_model_graph_type.py | 1 + .../graphql/schema/filter/db_model_filter.py | 1 + .../cpl/graphql/schema/mutation.py | 28 ++--- .../cpl/graphql/schema/root_subscription.py | 6 + .../cpl/graphql/schema/subscription.py | 109 ++++++++++++++++++ .../cpl/graphql/schema/subscription_field.py | 11 ++ src/cpl-graphql/cpl/graphql/service/schema.py | 10 +- src/cpl-graphql/cpl/graphql/typing.py | 5 +- 21 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 src/cpl-dependency/cpl/dependency/event_bus.py create mode 100644 src/cpl-graphql/cpl/graphql/event_bus/__init__.py create mode 100644 src/cpl-graphql/cpl/graphql/event_bus/memory.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/root_subscription.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/subscription.py create mode 100644 src/cpl-graphql/cpl/graphql/schema/subscription_field.py diff --git a/example/api/src/main.py b/example/api/src/main.py index ac906f9b..7dc53652 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,8 +1,10 @@ from starlette.responses import JSONResponse -from api.src.queries.cities import CityGraphType, CityFilter, CitySort -from api.src.queries.hello import UserGraphType#, UserFilter, UserSort, UserGraphType -from api.src.queries.user import UserFilter, UserSort +from cpl.dependency.event_bus import EventBusABC +from cpl.graphql.event_bus.memory import InMemoryEventBus +from queries.cities import CityGraphType, CityFilter, CitySort +from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType +from queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule from cpl.application.application_builder import ApplicationBuilder from cpl.auth.schema import User, Role @@ -17,7 +19,7 @@ from cpl.graphql.graphql_module import GraphQLModule from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao -from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation +from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription from permissions import PostPermissions from queries.hello import HelloQuery from scoped_service import ScopedService @@ -41,6 +43,7 @@ def main(): .add_module(GraphQLModule) .add_module(GraphQLAuthModule) .add_scoped(ScopedService) + .add_singleton(EventBusABC, InMemoryEventBus) .add_cache(User) .add_cache(Role) .add_transient(CityGraphType) @@ -66,6 +69,7 @@ def main(): .add_transient(PostFilter) .add_transient(PostSort) .add_transient(PostMutation) + .add_transient(PostSubscription) ) app = builder.build() @@ -96,6 +100,8 @@ def main(): schema.mutation.with_mutation("post", PostMutation).with_public() + schema.subscription.with_subscription("post", PostSubscription) + app.with_auth_root_queries(True) app.with_auth_root_mutations(True) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index d12f308c..eab7525f 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -1,3 +1,4 @@ +from cpl.dependency.event_bus import EventBusABC from cpl.graphql.query_context import QueryContext from cpl.graphql.schema.db_model_graph_type import DbModelGraphType from cpl.graphql.schema.filter.db_model_filter import DbModelFilter @@ -5,6 +6,7 @@ from cpl.graphql.schema.input import Input from cpl.graphql.schema.mutation import Mutation from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder +from cpl.graphql.schema.subscription import Subscription from model.author_dao import AuthorDao from model.author_query import AuthorGraphType, AuthorFilter from model.post import Post @@ -19,6 +21,7 @@ class PostFilter(DbModelFilter[Post]): self.string_field("title") self.string_field("content") + class PostSort(Sort[Post]): def __init__(self): Sort.__init__(self) @@ -26,6 +29,7 @@ class PostSort(Sort[Post]): self.field("title", SortOrder) self.field("content", SortOrder) + class PostGraphType(DbModelGraphType[Post]): def __init__(self, authors: AuthorDao): @@ -42,7 +46,7 @@ class PostGraphType(DbModelGraphType[Post]): def r_name(ctx: QueryContext): return ctx.user.username == "admin" - self.object_field("author", AuthorGraphType, resolver=_a).with_public(True)# .with_require_any([], [r_name])) + self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name])) self.string_field( "title", resolver=lambda root: root.title, @@ -64,6 +68,7 @@ class PostCreateInput(Input[Post]): self.string_field("content").with_required() self.int_field("author_id").with_required() + class PostUpdateInput(Input[Post]): title: str content: str @@ -75,13 +80,31 @@ class PostUpdateInput(Input[Post]): self.string_field("title").with_required(False) self.string_field("content").with_required(False) + +class PostSubscription(Subscription): + def __init__(self, bus: EventBusABC): + Subscription.__init__(self) + self._bus = bus + + async def post_changed(): + async for event in await self._bus.subscribe("postChange"): + print("Event:", event, type(event)) + yield event + + def selector(event: Post, info) -> bool: + return True + + self.subscription_field("postChange", PostGraphType, post_changed, selector) + + class PostMutation(Mutation): - def __init__(self, posts: PostDao, authors: AuthorDao): + def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC): Mutation.__init__(self) self._posts = posts self._authors = authors + self._bus = bus self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument( "input", @@ -112,6 +135,7 @@ class PostMutation(Mutation): post.content = input.content if input.content is not None else post.content await self._posts.update(post) + await self._bus.publish("postChange", post) return True async def delete_post(self, id: int) -> bool: @@ -127,4 +151,3 @@ class PostMutation(Mutation): return False await self._posts.restore(post) return True - diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py index c53ce008..19a1f774 100644 --- a/example/api/src/queries/hello.py +++ b/example/api/src/queries/hello.py @@ -1,5 +1,5 @@ -from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City -from api.src.queries.user import User, UserFilter, UserSort, UserGraphType +from queries.cities import CityFilter, CitySort, CityGraphType, City +from queries.user import User, UserFilter, UserSort, UserGraphType from cpl.api.middleware.request import get_request from cpl.auth.schema import UserDao, User from cpl.graphql.schema.filter.filter import Filter @@ -38,6 +38,7 @@ cities = [City(i, f"City {i}") for i in range(1, 101)] # resolver=lambda root: root.username, # ) + class HelloQuery(Query): def __init__(self): Query.__init__(self) diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 05a291e3..4f3ae5a4 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -21,9 +21,7 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): - def __init__( - self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao - ): + def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._provider = provider diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index 474bf6ce..b482229c 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -119,6 +119,13 @@ class MySQLPool: try: async with await con.cursor(dictionary=True) as cursor: res = await self._exec_sql(cursor, query, args, multi) - return list(res) + decoded_res = [] + for row in res: + decoded_row = { + k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items() + } + decoded_res.append(decoded_row) + + return decoded_res finally: await con.close() diff --git a/src/cpl-dependency/cpl/dependency/event_bus.py b/src/cpl-dependency/cpl/dependency/event_bus.py new file mode 100644 index 00000000..efd372aa --- /dev/null +++ b/src/cpl-dependency/cpl/dependency/event_bus.py @@ -0,0 +1,10 @@ +from abc import abstractmethod, ABC +from typing import Any, AsyncGenerator + + +class EventBusABC(ABC): + @abstractmethod + async def publish(self, channel: str, event: Any) -> None: ... + + @abstractmethod + async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ... diff --git a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py index 1c0b6592..ad8f18b8 100644 --- a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py +++ b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py @@ -1,9 +1,11 @@ from typing import Protocol, Type, runtime_checkable from cpl.graphql.schema.field import Field +from cpl.graphql.schema.subscription_field import SubscriptionField @runtime_checkable class StrawberryProtocol(Protocol): def to_strawberry(self) -> Type: ... - def get_fields(self) -> dict[str, Field]: ... + + def get_fields(self) -> dict[str, Field | SubscriptionField]: ... diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py index 8c15dec8..5a67de13 100644 --- a/src/cpl-graphql/cpl/graphql/application/graphql_app.py +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -114,4 +114,6 @@ class GraphQLApp(WebApp): if self._with_graphiql: self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql") if self._with_playground: - self._logger.warning(f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground") + self._logger.warning( + f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground" + ) diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py index c431eee8..67444e3b 100644 --- a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py @@ -8,13 +8,12 @@ from cpl.graphql.schema.mutation import Mutation class ApiKeyMutation(Mutation): def __init__( - self, - - logger: APILogger, - api_key_dao: ApiKeyDao, - api_key_permission_dao: ApiKeyPermissionDao, - permission_dao: ApiKeyPermissionDao, - keycloak_admin: KeycloakAdmin, + self, + logger: APILogger, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + permission_dao: ApiKeyPermissionDao, + keycloak_admin: KeycloakAdmin, ): Mutation.__init__(self) self._logger = logger @@ -61,9 +60,7 @@ class ApiKeyMutation(Mutation): api_key = ApiKey.new(obj.identifier) await self._api_key_dao.create(api_key) api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}]) - await self._api_key_permission_dao.create_many( - [ApiKeyPermission(0, api_key.id, x) for x in obj.permissions] - ) + await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions]) return api_key async def resolve_update(self, input: ApiKeyUpdateInput): diff --git a/src/cpl-graphql/cpl/graphql/event_bus/__init__.py b/src/cpl-graphql/cpl/graphql/event_bus/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/event_bus/memory.py b/src/cpl-graphql/cpl/graphql/event_bus/memory.py new file mode 100644 index 00000000..4d74c1af --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/event_bus/memory.py @@ -0,0 +1,27 @@ +import asyncio +from typing import Any, AsyncGenerator + +from cpl.dependency.event_bus import EventBusABC + + +class InMemoryEventBus(EventBusABC): + def __init__(self): + self._subscribers: dict[str, list[asyncio.Queue]] = {} + + async def publish(self, channel: str, event: Any) -> None: + queues = self._subscribers.get(channel, []) + for q in queues.copy(): + await q.put(event) + + async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: + q = asyncio.Queue() + if channel not in self._subscribers: + self._subscribers[channel] = [] + self._subscribers[channel].append(q) + + try: + while True: + item = await q.get() + yield item + finally: + self._subscribers[channel].remove(q) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py index b749d16e..3672e119 100644 --- a/src/cpl-graphql/cpl/graphql/graphql_module.py +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -8,13 +8,14 @@ from cpl.graphql.schema.filter.int_filter import IntFilter from cpl.graphql.schema.filter.string_filter import StringFilter from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.schema.root_subscription import RootSubscription from cpl.graphql.service.graphql import GraphQLService from cpl.graphql.service.schema import Schema class GraphQLModule(Module): dependencies = [ApiModule] - singleton = [Schema, RootQuery, RootMutation] + singleton = [Schema, RootQuery, RootMutation, RootSubscription] scoped = [GraphQLService] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 9d600ab9..650fc71e 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -9,6 +9,7 @@ from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.utils.type_collector import TypeCollector + class CollectionGraphTypeFactory: @classmethod def get(cls, node_type: Type[StrawberryProtocol]) -> Type: @@ -23,11 +24,7 @@ class CollectionGraphTypeFactory: gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type - gql_cls = type( - type_name, - (), - {} - ) + gql_cls = type(type_name, (), {}) TypeCollector.set(type_name, gql_cls) @@ -45,8 +42,6 @@ class CollectionGraphTypeFactory: return gql_type - - class Collection: def __init__(self, nodes: list[T], total_count: int, count: int): self._nodes = nodes diff --git a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py index 2b9a39bb..ed4153a2 100644 --- a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py @@ -26,6 +26,7 @@ class DbModelGraphType(GraphType[T], Generic[T]): if Configuration.get("GraphQLAuthModuleEnabled", False): from cpl.graphql.auth.user.user_graph_type import UserGraphType + self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public) self.string_field("created", lambda root: root.created).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py index a7a22cb7..4a91544c 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -16,6 +16,7 @@ class DbModelFilter(Filter[T], Generic[T]): self.field("deleted", BoolFilter).with_public(public) if Configuration.get("GraphQLAuthModuleEnabled", False): from cpl.graphql.auth.user.user_filter import UserFilter + self.field("editor", lambda: UserFilter).with_public(public) self.field("created", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/mutation.py b/src/cpl-graphql/cpl/graphql/schema/mutation.py index 82a707e9..d336f3b1 100644 --- a/src/cpl-graphql/cpl/graphql/schema/mutation.py +++ b/src/cpl-graphql/cpl/graphql/schema/mutation.py @@ -28,14 +28,14 @@ class Mutation(QueryABC): @staticmethod async def _resolve_assignments( - foreign_objs: list[int], - resolved_obj: T, - reference_key_own: Union[str, property], - reference_key_foreign: Union[str, property], - source_dao: DataAccessObjectABC[T], - join_dao: DataAccessObjectABC[T], - join_type: Type[DbJoinModelABC], - foreign_dao: DataAccessObjectABC[T], + foreign_objs: list[int], + resolved_obj: T, + reference_key_own: Union[str, property], + reference_key_foreign: Union[str, property], + source_dao: DataAccessObjectABC[T], + join_dao: DataAccessObjectABC[T], + join_type: Type[DbJoinModelABC], + foreign_dao: DataAccessObjectABC[T], ): if foreign_objs is None: return @@ -44,9 +44,7 @@ class Mutation(QueryABC): if isinstance(reference_key_foreign, property): reference_key_foreign_attr = reference_key_foreign.fget.__name__ - foreign_list = await join_dao.find_by( - [{reference_key_own: resolved_obj.id}, {"deleted": False}] - ) + foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}]) to_delete = ( foreign_list @@ -61,9 +59,7 @@ class Mutation(QueryABC): foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list] deleted_foreign_ids = [ getattr(x, reference_key_foreign_attr) - for x in await join_dao.find_by( - [{reference_key_own: resolved_obj.id}, {"deleted": True}] - ) + for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}]) ] to_create = [ @@ -94,6 +90,4 @@ class Mutation(QueryABC): foreign_changes = [*to_delete, *to_create, *to_restore] if len(foreign_changes) > 0: await source_dao.touch(resolved_obj) - await foreign_dao.touch_many_by_id( - [getattr(x, reference_key_foreign_attr) for x in foreign_changes] - ) + await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes]) diff --git a/src/cpl-graphql/cpl/graphql/schema/root_subscription.py b/src/cpl-graphql/cpl/graphql/schema/root_subscription.py new file mode 100644 index 00000000..fab2bc8f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_subscription.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.subscription import Subscription + + +class RootSubscription(Subscription): + def __init__(self): + Subscription.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/subscription.py b/src/cpl-graphql/cpl/graphql/schema/subscription.py new file mode 100644 index 00000000..5d73a980 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/subscription.py @@ -0,0 +1,109 @@ +import inspect +import types +from abc import ABC +from typing import Any, Callable, Dict, Type, AsyncGenerator, Optional + +import strawberry +from strawberry.exceptions import StrawberryException + +from cpl.dependency import ServiceProvider, get_provider, inject +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.subscription_field import SubscriptionField +from cpl.graphql.typing import Selector +from cpl.graphql.utils.type_collector import TypeCollector + + +class Subscription(ABC, StrawberryProtocol): + + @inject + def __init__(self, provider: ServiceProvider): + ABC.__init__(self) + self._provider = provider + self._fields: Dict[str, SubscriptionField] = {} + + @property + def fields(self) -> dict[str, SubscriptionField]: + return self._fields + + @property + def fields_count(self) -> int: + return len(self._fields) + + def get_fields(self) -> dict[str, SubscriptionField]: + return self._fields + + def subscription_field( + self, + name: str, + type_: Type, + resolver: Callable[..., AsyncGenerator], + selector: Selector = None, + ) -> SubscriptionField: + f = SubscriptionField(name, type_, resolver, selector) + self._fields[name] = f + return f + + def with_subscription(self, name: str, sub_cls: Type["Subscription"]) -> SubscriptionField: + sub = self._provider.get_service(sub_cls) + if not sub: + raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in service provider") + + async def _resolver(root, info): + return sub + + self._fields[name] = SubscriptionField(name, sub.to_strawberry(), resolver=_resolver) + return self._fields[name] + + @staticmethod + def _type_to_strawberry(t: Type) -> Type: + _t = get_provider().get_service(t) + if isinstance(_t, StrawberryProtocol): + return _t.to_strawberry() + return t + + @staticmethod + def _build_resolver(f: SubscriptionField) -> Callable: + async def _resolver(root, info): + async for event in f.resolver(root, info): + if not f.selector or f.selector(event, info): + yield event + + return _resolver + + def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + + gql_cls = type(cls.__name__, (), {}) + TypeCollector.set(cls, gql_cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + t = f.type + if isinstance(t, types.GenericAlias): + t = t.__args__[0] + elif isinstance(t, type) and issubclass(t, StrawberryProtocol): + t = self._type_to_strawberry(t) + + annotations[name] = Optional[t] + + try: + namespace[name] = strawberry.subscription(resolver=self._build_resolver(f)) + + except StrawberryException as e: + raise Exception(f"Error converting subscription field '{f.name}': {e}") from e + + gql_cls.__annotations__ = annotations + for k, v in namespace.items(): + setattr(gql_cls, k, v) + + try: + gql_type = strawberry.type(gql_cls) + except Exception as e: + raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e + + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/subscription_field.py b/src/cpl-graphql/cpl/graphql/schema/subscription_field.py new file mode 100644 index 00000000..3e60cda4 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/subscription_field.py @@ -0,0 +1,11 @@ +from typing import Type, Callable + +from cpl.graphql.typing import Selector + + +class SubscriptionField: + def __init__(self, name: str, type_: Type, resolver: Callable, selector: Selector = None): + self.name = name + self.type = type_ + self.resolver = resolver + self.selector = selector diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py index 9141f455..2a5a2dd0 100644 --- a/src/cpl-graphql/cpl/graphql/service/schema.py +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -8,6 +8,7 @@ from cpl.dependency.service_provider import ServiceProvider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.schema.root_subscription import RootSubscription class Schema: @@ -38,6 +39,13 @@ class Schema: raise ValueError("RootMutation not registered in service provider") return mutation + @property + def subscription(self) -> RootSubscription: + subscription = self._provider.get_service(RootSubscription) + if not subscription: + raise ValueError("RootSubscription not registered in service provider") + return subscription + def with_type(self, t: Type[StrawberryProtocol]) -> Self: self._types[t.__name__] = t return self @@ -61,7 +69,7 @@ class Schema: self._schema = strawberry.Schema( query=query.to_strawberry() if query.fields_count > 0 else None, mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None, - subscription=None, + subscription=self.subscription.to_strawberry() if self.subscription.fields_count > 0 else None, types=self._get_types(), ) return self._schema diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py index 3dd33106..bb8cda8e 100644 --- a/src/cpl-graphql/cpl/graphql/typing.py +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -1,11 +1,14 @@ from enum import Enum -from typing import Type, Callable, List, Tuple, Awaitable +from typing import Type, Callable, List, Tuple, Awaitable, Any + +import strawberry from cpl.auth.permission import Permissions from cpl.graphql.query_context import QueryContext TQuery = Type["Query"] Resolver = Callable +Selector = Callable[[Any, strawberry.types.Info], bool] ScalarType = str | int | float | bool | object AttributeName = str | property TRequireAnyPermissions = List[Enum | Permissions] | None From 3774cef56a4468546bf1f0f11a854a80fdbc975e Mon Sep 17 00:00:00 2001 From: edraft Date: Wed, 8 Oct 2025 17:27:11 +0200 Subject: [PATCH 19/20] Updated permissions #181 --- .../graphql/auth/api_key/api_key_mutation.py | 8 ++++---- .../cpl/graphql/auth/graphql_auth_module.py | 17 ++++++++++++----- .../cpl/graphql/auth/user/user_graph_type.py | 8 ++++---- src/cpl-graphql/cpl/graphql/schema/field.py | 6 ++++-- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py index 67444e3b..dd3a4665 100644 --- a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py @@ -25,7 +25,7 @@ class ApiKeyMutation(Mutation): self.int_field( "create", self.resolve_create, - ).with_require_any_permission(Permissions.users_create).with_argument( + ).with_require_any_permission(Permissions.api_keys_create).with_argument( "input", ApiKeyCreateInput, ).with_required() @@ -33,7 +33,7 @@ class ApiKeyMutation(Mutation): self.bool_field( "update", self.resolve_update, - ).with_require_any_permission(Permissions.users_update).with_argument( + ).with_require_any_permission(Permissions.api_keys_update).with_argument( "input", ApiKeyUpdateInput, ).with_required() @@ -41,7 +41,7 @@ class ApiKeyMutation(Mutation): self.bool_field( "delete", self.resolve_delete, - ).with_require_any_permission(Permissions.users_delete).with_argument( + ).with_require_any_permission(Permissions.api_keys_delete).with_argument( "id", int, ).with_required() @@ -49,7 +49,7 @@ class ApiKeyMutation(Mutation): self.bool_field( "restore", self.resolve_restore, - ).with_require_any_permission(Permissions.users_delete).with_argument( + ).with_require_any_permission(Permissions.api_keys_delete).with_argument( "id", int, ).with_required() diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py index 9b41cc8e..7ce2a0b4 100644 --- a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -1,3 +1,4 @@ +from cpl.auth.permission import Permissions from cpl.auth.schema import UserDao, ApiKeyDao, RoleDao from cpl.core.configuration import Configuration from cpl.dependency import ServiceProvider @@ -53,11 +54,17 @@ class GraphQLAuthModule(Module): raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") schema = provider.get_service(Schema) - schema.query.dao_collection_field(UserGraphType, UserDao, "users", UserFilter, UserSort).with_public(public) - schema.query.dao_collection_field(ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort).with_public( - public - ) - schema.query.dao_collection_field(RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort).with_public(public) + schema.query.dao_collection_field( + UserGraphType, UserDao, "users", UserFilter, UserSort + ).with_require_any_permission(Permissions.users).with_public(public) + + schema.query.dao_collection_field( + ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort + ).with_require_any_permission(Permissions.api_keys).with_public(public) + + schema.query.dao_collection_field( + RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort + ).with_require_any_permission(Permissions.roles).with_public(public) @staticmethod def with_auth_root_mutations(provider: ServiceProvider, public: bool = False): diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py index 73a44c37..f0ffa1ab 100644 --- a/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py @@ -4,9 +4,9 @@ from cpl.graphql.schema.db_model_graph_type import DbModelGraphType class UserGraphType(DbModelGraphType[User]): - def __init__(self): + def __init__(self, public: bool = False): DbModelGraphType.__init__(self) - self.string_field(User.keycloak_id, lambda root: root.keycloak_id) - self.string_field(User.username, lambda root: root.username) - self.string_field(User.email, lambda root: root.email) + self.string_field(User.keycloak_id, lambda root: root.keycloak_id).with_public(public) + self.string_field(User.username, lambda root: root.username).with_public(public) + self.string_field(User.email, lambda root: root.email).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py index cea91c93..7866fafa 100644 --- a/src/cpl-graphql/cpl/graphql/schema/field.py +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -133,7 +133,9 @@ class Field: return self def with_public(self, public: bool = True) -> Self: - assert self._require_any is None, "Field cannot be public and have require_any set" - assert self._require_any_permission is None, "Field cannot be public and have require_any_permission set" + if public: + self._require_any = None + self._require_any_permission = None + self._public = public return self From 545540d05dde8774b7f31eb9b528aeff1796e19c Mon Sep 17 00:00:00 2001 From: edraft Date: Wed, 8 Oct 2025 21:22:51 +0200 Subject: [PATCH 20/20] Added subscriptions final #181 --- example/api/src/main.py | 2 +- example/api/src/model/post_query.py | 9 +- src/cpl-api/cpl/api/application/web_app.py | 24 +++ src/cpl-api/cpl/api/middleware/request.py | 3 +- src/cpl-api/cpl/api/model/websocket_route.py | 31 ++++ src/cpl-api/cpl/api/registry/route.py | 17 +- src/cpl-api/cpl/api/router.py | 16 ++ .../cpl/core/log/structured_logger.py | 2 +- .../cpl/graphql/_endpoints/graphiql.py | 41 ++++- .../graphql/_endpoints/lazy_graphql_app.py | 27 ++++ .../cpl/graphql/application/graphql_app.py | 7 + .../cpl/graphql/schema/subscription.py | 151 ++++++++---------- .../cpl/graphql/schema/subscription_field.py | 26 ++- src/cpl-graphql/cpl/graphql/service/schema.py | 3 +- 14 files changed, 243 insertions(+), 116 deletions(-) create mode 100644 src/cpl-api/cpl/api/model/websocket_route.py create mode 100644 src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 7dc53652..4c71bbc9 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -100,7 +100,7 @@ def main(): schema.mutation.with_mutation("post", PostMutation).with_public() - schema.subscription.with_subscription("post", PostSubscription) + schema.subscription.with_subscription(PostSubscription) app.with_auth_root_queries(True) app.with_auth_root_mutations(True) diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index eab7525f..5fe134f6 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -86,15 +86,10 @@ class PostSubscription(Subscription): Subscription.__init__(self) self._bus = bus - async def post_changed(): - async for event in await self._bus.subscribe("postChange"): - print("Event:", event, type(event)) - yield event - def selector(event: Post, info) -> bool: - return True + return event.id == 101 - self.subscription_field("postChange", PostGraphType, post_changed, selector) + self.subscription_field("postChange", PostGraphType, selector).with_public() class PostMutation(Mutation): diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index f994444e..f94694f9 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -169,6 +169,30 @@ class WebApp(WebAppABC): return self + def with_websocket( + self, + path: str, + fn: TEndpoint, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self._check_for_app() + assert path is not None, "path must not be None" + assert fn is not None, "fn must not be None" + + Router.websocket(path, registry=self._routes)(fn) + + if authentication: + Router.authenticate()(fn) + + if roles or permissions or policies: + Router.authorize(roles, permissions, policies, match)(fn) + + return self + def with_middleware(self, middleware: PartialMiddleware) -> Self: self._check_for_app() diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 4f3ae5a4..d5e73721 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -5,6 +5,7 @@ from uuid import uuid4 from starlette.requests import Request from starlette.types import Scope, Receive, Send +from starlette.websockets import WebSocket from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger @@ -33,7 +34,7 @@ class RequestMiddleware(ASGIMiddleware): self._ctx_token = None async def __call__(self, scope: Scope, receive: Receive, send: Send): - request = Request(scope, receive, send) + request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send) await self.set_request_data(request) try: diff --git a/src/cpl-api/cpl/api/model/websocket_route.py b/src/cpl-api/cpl/api/model/websocket_route.py new file mode 100644 index 00000000..3c09ca3f --- /dev/null +++ b/src/cpl-api/cpl/api/model/websocket_route.py @@ -0,0 +1,31 @@ +from typing import Callable + +import starlette.routing + + +class WebSocketRoute: + + def __init__(self, path: str, fn: Callable, **kwargs): + self._path = path + self._fn = fn + + self._kwargs = kwargs + + @property + def name(self) -> str: + return self._fn.__name__ + + @property + def fn(self) -> Callable: + return self._fn + + @property + def path(self) -> str: + return self._path + + @property + def kwargs(self) -> dict: + return self._kwargs + + def to_starlette(self, *args) -> starlette.routing.WebSocketRoute: + return starlette.routing.WebSocketRoute(self._path, self._fn) diff --git a/src/cpl-api/cpl/api/registry/route.py b/src/cpl-api/cpl/api/registry/route.py index e030007b..83ce7862 100644 --- a/src/cpl-api/cpl/api/registry/route.py +++ b/src/cpl-api/cpl/api/registry/route.py @@ -1,32 +1,35 @@ -from typing import Optional +from typing import Optional, Union from cpl.api.model.api_route import ApiRoute +from cpl.api.model.websocket_route import WebSocketRoute from cpl.core.abc.registry_abc import RegistryABC +TRoute = Union[ApiRoute, WebSocketRoute] + class RouteRegistry(RegistryABC): def __init__(self): RegistryABC.__init__(self) - def extend(self, items: list[ApiRoute]): + def extend(self, items: list[TRoute]): for policy in items: self.add(policy) - def add(self, item: ApiRoute): - assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" + def add(self, item: TRoute): + assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute" if item.path in self._items: raise ValueError(f"ApiRoute {item.path} is already registered") self._items[item.path] = item - def set(self, item: ApiRoute): + def set(self, item: TRoute): assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" self._items[item.path] = item - def get(self, key: str) -> Optional[ApiRoute]: + def get(self, key: str) -> Optional[TRoute]: return self._items.get(key) - def all(self) -> list[ApiRoute]: + def all(self) -> list[TRoute]: return list(self._items.values()) diff --git a/src/cpl-api/cpl/api/router.py b/src/cpl-api/cpl/api/router.py index 27dfd5ab..55369c38 100644 --- a/src/cpl-api/cpl/api/router.py +++ b/src/cpl-api/cpl/api/router.py @@ -91,6 +91,22 @@ class Router: return inner + @classmethod + def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs): + from cpl.api.model.websocket_route import WebSocketRoute + + if not registry: + routes = get_provider().get_service(RouteRegistry) + else: + routes = registry + + def inner(fn): + routes.add(WebSocketRoute(path, fn, **kwargs)) + setattr(fn, "_route_path", path) + return fn + + return inner + @classmethod def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): from cpl.api.model.api_route import ApiRoute diff --git a/src/cpl-core/cpl/core/log/structured_logger.py b/src/cpl-core/cpl/core/log/structured_logger.py index 2d1b9eca..e8e45849 100644 --- a/src/cpl-core/cpl/core/log/structured_logger.py +++ b/src/cpl-core/cpl/core/log/structured_logger.py @@ -68,7 +68,7 @@ class StructuredLogger(Logger): message["request"] = { "url": str(request.url), - "method": request.method, + "method": request.method if request.scope == "http" else "websocket", "scope": self._scope_to_json(request), } if isinstance(request, Request) and request.scope == "http": diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py index 70a81ad3..a369fd64 100644 --- a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py @@ -9,7 +9,10 @@ async def graphiql_endpoint(request): GraphiQL - +
@@ -21,13 +24,39 @@ async def graphiql_endpoint(request): + + +