Renamed project dirs
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s

This commit is contained in:
2025-10-11 09:32:13 +02:00
parent f1aaaf2a5b
commit 90ff8d466d
319 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1 @@
__version__ = "1.0.0"

View File

@@ -0,0 +1,69 @@
from starlette.responses import HTMLResponse
async def graphiql_endpoint(request):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>GraphiQL</title>
<link
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
rel="stylesheet"
/>
</head>
<body style="margin:0;overflow:hidden;">
<div id="graphiql" style="height:100vh;"></div>
<!-- React + ReactDOM -->
<script src="https://unpkg.com/react@18.2.0/umd/react.production.min.js"></script>
<script src="https://unpkg.com/react-dom@18.2.0/umd/react-dom.production.min.js"></script>
<!-- GraphiQL -->
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
<!-- GraphQL over WebSocket client -->
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
<script>
const httpUrl = window.location.origin + '/api/graphql';
const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
window.location.host + '/api/graphql/ws';
// HTTP fetcher for queries & mutations
const httpFetcher = async (params) => {
const res = await fetch(httpUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(params),
});
return res.json();
};
// WebSocket fetcher for subscriptions
const wsClient = graphqlWs.createClient({ url: wsUrl });
const wsFetcher = (params) => ({
subscribe: (sink) => ({
unsubscribe: wsClient.subscribe(params, sink),
}),
});
// smart fetcher wrapper (decides HTTP or WS)
const graphQLFetcher = (params) => {
if (params.query.trim().startsWith('subscription')) {
return wsFetcher(params);
}
return httpFetcher(params);
};
ReactDOM.render(
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),
document.getElementById('graphiql'),
);
</script>
</body>
</html>
"""
)

View File

@@ -0,0 +1,13 @@
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from cpl.graphql.service.graphql import GraphQLService
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:
body = await request.json()
query = body.get("query")
variables = body.get("variables")
response_data = await service.execute(query, variables, request)
return JSONResponse(response_data)

View File

@@ -0,0 +1,27 @@
from starlette.requests import Request
from starlette.responses import Response
from strawberry.asgi import GraphQL
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from cpl.dependency import ServiceProvider
from cpl.graphql.service.schema import Schema
class LazyGraphQLApp:
def __init__(self, services: ServiceProvider):
self._services = services
self._graphql_app = None
async def __call__(self, scope, receive, send):
if self._graphql_app is None:
schema = self._services.get_service(Schema)
if not schema or not schema.schema:
raise RuntimeError("GraphQL Schema not available yet")
self._graphql_app = GraphQL(
schema.schema,
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
)
await self._graphql_app(scope, receive, send)

View File

@@ -0,0 +1,29 @@
from starlette.requests import Request
from starlette.responses import Response, HTMLResponse
async def playground_endpoint(request: Request) -> Response:
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<meta charset=utf-8/>
<title>GraphQL Playground</title>
<link rel="stylesheet" href="https://unpkg.com/graphql-playground-react/build/static/css/index.css" />
<link rel="shortcut icon" href="https://raw.githubusercontent.com/graphql/graphql-playground/master/packages/graphql-playground-react/public/favicon.png" />
<script src="https://unpkg.com/graphql-playground-react/build/static/js/middleware.js"></script>
</head>
<body>
<div id="root"/>
<script>
window.addEventListener('load', function () {
GraphQLPlayground.init(document.getElementById('root'), {
endpoint: '/api/graphql'
})
})
</script>
</body>
</html>
"""
)

View File

View File

@@ -0,0 +1,227 @@
import functools
import inspect
import types
from abc import ABC
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional
import strawberry
from async_property.base import AsyncPropertyDescriptor
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Resolver, AttributeName
from cpl.graphql.utils.type_collector import TypeCollector
class QueryABC(StrawberryProtocol, ABC):
def __init__(self):
ABC.__init__(self)
self._fields: dict[str, Field] = {}
@property
def fields(self) -> dict[str, Field]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
name: AttributeName,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
if not isinstance(t, type) and callable(t):
return self.field(name, t, resolver)
return self.field(name, t().to_strawberry(), resolver)
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
_type = arg.type
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
_type = _type().to_strawberry()
ann = Optional[_type] if arg.optional else _type
if arg.default is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
async def _resolver(*args, **kwargs):
if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user()
if user is None:
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
if f.require_any_permission:
if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
if f.require_any:
perms, resolvers = f.require_any
if not any([await user.has_permission(p) for p in perms]):
ctx = QueryContext([x.name for x in await user.permissions])
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
result = r(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda root: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
@staticmethod
def _type_to_strawberry(t: Type) -> Type:
_t = get_provider().get_service(t)
if isinstance(_t, StrawberryProtocol):
return _t.to_strawberry()
return t
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
# register early to handle recursive types
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(name, property):
name = name.fget.__name__
if isinstance(name, AsyncPropertyDescriptor):
name = name.field_name
if isinstance(t, types.GenericAlias):
t = t.__args__[0]
if callable(t) and not isinstance(t, type):
t = self._type_to_strawberry(t())
elif issubclass(t, StrawberryProtocol):
t = self._type_to_strawberry(t)
annotations[name] = t if not f.optional else Optional[t]
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
for k, v in namespace.items():
if isinstance(k, property):
k = k.fget.__name__
if isinstance(k, AsyncPropertyDescriptor):
k = k.field_name
setattr(gql_cls, k, v)
try:
gql_cls.__annotations__ = annotations
gql_type = strawberry.type(gql_cls)
except Exception as e:
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,11 @@
from typing import Protocol, Type, runtime_checkable
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.subscription_field import SubscriptionField
@runtime_checkable
class StrawberryProtocol(Protocol):
def to_strawberry(self) -> Type: ...
def get_fields(self) -> dict[str, Field | SubscriptionField]: ...

View File

@@ -0,0 +1 @@
from .graphql_app import WebApp

View File

@@ -0,0 +1,126 @@
import socket
from enum import Enum
from typing import Self
from cpl.api.application import WebApp
from cpl.api.model.validation_match import ValidationMatch
from cpl.application.abc.application_abc import __not_implemented__
from cpl.core.environment import Environment
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
from cpl.graphql._endpoints.graphql import graphql_endpoint
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
from cpl.graphql._endpoints.playground import playground_endpoint
from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.service.schema import Schema
class GraphQLApp(WebApp):
def __init__(self, services: ServiceProvider, modules: Modules):
WebApp.__init__(self, services, modules, [GraphQLModule])
self._with_graphiql = False
self._with_playground = False
def with_graphql(
self,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Schema:
self.with_route(
path="/api/graphql",
fn=graphql_endpoint,
method="POST",
authentication=authentication,
roles=roles,
permissions=permissions,
policies=policies,
match=match,
)
schema = self._services.get_service(Schema)
if schema is None:
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
#
# graphql_ws_app = GraphQL(
# schema,
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
# )
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
return schema
def with_graphiql(
self,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self.with_route(
path="/api/graphiql",
fn=graphiql_endpoint,
method="GET",
authentication=authentication,
roles=roles,
permissions=permissions,
policies=policies,
match=match,
)
self._with_graphiql = True
return self
def with_playground(
self,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self.with_route(
path="/api/playground",
fn=playground_endpoint,
method="GET",
authentication=authentication,
roles=roles,
permissions=permissions,
policies=policies,
match=match,
)
self._with_playground = True
return self
def with_auth_root_queries(self, public: bool = False):
try:
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
GraphQLAuthModule.with_auth_root_queries(self._services, public=public)
except ImportError:
__not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations)
def with_auth_root_mutations(self, public: bool = False):
try:
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
GraphQLAuthModule.with_auth_root_mutations(self._services, public=public)
except ImportError:
__not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations)
async def _log_before_startup(self):
host = self._api_settings.host
if host == "0.0.0.0" and Environment.get_environment() == "development":
host = "localhost"
elif host == "0.0.0.0":
host = socket.gethostbyname(socket.gethostname())
self._logger.info(f"Start API on {host}:{self._api_settings.port}")
if self._with_graphiql:
self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql")
if self._with_playground:
self._logger.warning(
f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground"
)

View File

View File

@@ -0,0 +1,10 @@
from cpl.auth.schema import ApiKey
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class ApiKeyFilter(DbModelFilter[ApiKey]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("identifier", StringFilter).with_public(public)

View File

@@ -0,0 +1,14 @@
from cpl.auth.schema import ApiKey, RolePermissionDao
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class ApiKeyGraphType(DbModelGraphType[ApiKey]):
def __init__(self, role_permission_dao: RolePermissionDao):
DbModelGraphType.__init__(self)
self.string_field(ApiKey.identifier, lambda root: root.identifier)
self.string_field(ApiKey.key, lambda root: root.key)
self.string_field(ApiKey.permissions, lambda root: root.permissions)
self.set_history_reference_dao(role_permission_dao, "apikeyid")

View File

@@ -0,0 +1,25 @@
from cpl.auth.schema import ApiKey
from cpl.core.typing import SerialId
from cpl.graphql.schema.input import Input
class ApiKeyCreateInput(Input[ApiKey]):
identifier: str
permissions: list[SerialId]
def __init__(self):
Input.__init__(self)
self.string_field("identifier").with_required()
self.list_field("permissions", SerialId)
class ApiKeyUpdateInput(Input[ApiKey]):
id: SerialId
identifier: str | None
permissions: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.string_field("identifier").with_required()
self.list_field("permissions", SerialId)

View File

@@ -0,0 +1,93 @@
from cpl.api import APILogger
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.permission import Permissions
from cpl.auth.schema import ApiKey, ApiKeyDao, ApiKeyPermissionDao, ApiKeyPermission
from cpl.graphql.auth.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput
from cpl.graphql.schema.mutation import Mutation
class ApiKeyMutation(Mutation):
def __init__(
self,
logger: APILogger,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
permission_dao: ApiKeyPermissionDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
self._api_key_dao = api_key_dao
self._api_key_permission_dao = api_key_permission_dao
self._permission_dao = permission_dao
self._keycloak_admin = keycloak_admin
self.int_field(
"create",
self.resolve_create,
).with_require_any_permission(Permissions.api_keys_create).with_argument(
"input",
ApiKeyCreateInput,
).with_required()
self.bool_field(
"update",
self.resolve_update,
).with_require_any_permission(Permissions.api_keys_update).with_argument(
"input",
ApiKeyUpdateInput,
).with_required()
self.bool_field(
"delete",
self.resolve_delete,
).with_require_any_permission(Permissions.api_keys_delete).with_argument(
"id",
int,
).with_required()
self.bool_field(
"restore",
self.resolve_restore,
).with_require_any_permission(Permissions.api_keys_delete).with_argument(
"id",
int,
).with_required()
async def resolve_create(self, obj: ApiKeyCreateInput):
self._logger.debug(f"create api key: {obj.__dict__}")
api_key = ApiKey.new(obj.identifier)
await self._api_key_dao.create(api_key)
api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}])
await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions])
return api_key
async def resolve_update(self, input: ApiKeyUpdateInput):
self._logger.debug(f"update api key: {input}")
api_key = await self._api_key_dao.get_by_id(input.id)
await self._resolve_assignments(
input.permissions or [],
api_key,
ApiKeyPermission.api_key_id,
ApiKeyPermission.permission_id,
self._api_key_dao,
self._api_key_permission_dao,
ApiKeyPermission,
self._permission_dao,
)
return api_key
async def resolve_delete(self, id: str):
self._logger.debug(f"delete api key: {id}")
api_key = await self._api_key_dao.get_by_id(id)
await self._api_key_dao.delete(api_key)
return True
async def resolve_restore(self, id: str):
self._logger.debug(f"restore api key: {id}")
api_key = await self._api_key_dao.get_by_id(id)
await self._api_key_dao.restore(api_key)
return True

View File

@@ -0,0 +1,9 @@
from cpl.auth.schema import ApiKey
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
from cpl.graphql.schema.sort.sort_order import SortOrder
class ApiKeySort(DbModelSort[ApiKey]):
def __init__(self):
DbModelSort.__init__(self)
self.field("identifier", SortOrder)

View File

@@ -0,0 +1,77 @@
from cpl.auth.permission import Permissions
from cpl.auth.schema import UserDao, ApiKeyDao, RoleDao
from cpl.core.configuration import Configuration
from cpl.dependency import ServiceProvider
from cpl.dependency.module.module import Module
from cpl.dependency.service_collection import ServiceCollection
from cpl.graphql.auth.api_key.api_key_filter import ApiKeyFilter
from cpl.graphql.auth.api_key.api_key_graph_type import ApiKeyGraphType
from cpl.graphql.auth.api_key.api_key_mutation import ApiKeyMutation
from cpl.graphql.auth.api_key.api_key_sort import ApiKeySort
from cpl.graphql.auth.role.role_filter import RoleFilter
from cpl.graphql.auth.role.role_graph_type import RoleGraphType
from cpl.graphql.auth.role.role_mutation import RoleMutation
from cpl.graphql.auth.role.role_sort import RoleSort
from cpl.graphql.auth.user.user_filter import UserFilter
from cpl.graphql.auth.user.user_graph_type import UserGraphType
from cpl.graphql.auth.user.user_mutation import UserMutation
from cpl.graphql.auth.user.user_sort import UserSort
from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.service.schema import Schema
class GraphQLAuthModule(Module):
dependencies = [GraphQLModule]
transient = [
UserGraphType,
UserMutation,
UserFilter,
UserSort,
ApiKeyGraphType,
ApiKeyMutation,
ApiKeyFilter,
ApiKeySort,
RoleGraphType,
RoleMutation,
RoleFilter,
RoleSort,
]
@staticmethod
def register(collection: ServiceCollection):
Configuration.set("GraphQLAuthModuleEnabled", True)
@staticmethod
def configure(provider: ServiceProvider):
schema = provider.get_service(Schema)
schema.with_type(UserGraphType)
schema.with_type(ApiKeyGraphType)
schema.with_type(RoleGraphType)
@staticmethod
def with_auth_root_queries(provider: ServiceProvider, public: bool = False):
if not Configuration.get("GraphQLAuthModuleEnabled", False):
raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'")
schema = provider.get_service(Schema)
schema.query.dao_collection_field(
UserGraphType, UserDao, "users", UserFilter, UserSort
).with_require_any_permission(Permissions.users).with_public(public)
schema.query.dao_collection_field(
ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort
).with_require_any_permission(Permissions.api_keys).with_public(public)
schema.query.dao_collection_field(
RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort
).with_require_any_permission(Permissions.roles).with_public(public)
@staticmethod
def with_auth_root_mutations(provider: ServiceProvider, public: bool = False):
if not Configuration.get("GraphQLAuthModuleEnabled", False):
raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'")
schema = provider.get_service(Schema)
schema.mutation.with_mutation("user", UserMutation).with_public(public)
schema.mutation.with_mutation("apiKey", ApiKeyMutation).with_public(public)
schema.mutation.with_mutation("role", RoleMutation).with_public(public)

View File

@@ -0,0 +1,11 @@
from cpl.auth.schema import User, Role
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class RoleFilter(DbModelFilter[Role]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("name", StringFilter).with_public(public)
self.field("description", StringFilter).with_public(public)

View File

@@ -0,0 +1,14 @@
from cpl.auth.schema import Role
from cpl.graphql.auth.user.user_graph_type import UserGraphType
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class RoleGraphType(DbModelGraphType[Role]):
def __init__(self, public: bool = False):
DbModelGraphType.__init__(self)
self.string_field("name", lambda root: root.name).with_public(public)
self.string_field("description", lambda root: root.description).with_public(public)
self.list_field("permissions", str, lambda root: root.permissions).with_public(public)
self.list_field("users", UserGraphType, lambda root: root.users).with_public(public)

View File

@@ -0,0 +1,29 @@
from cpl.auth.schema import User, Role
from cpl.core.typing import SerialId
from cpl.graphql.schema.input import Input
class RoleCreateInput(Input[Role]):
name: str
description: str | None
permissions: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.string_field("name").with_required()
self.string_field("description")
self.list_field("permissions", SerialId)
class RoleUpdateInput(Input[Role]):
id: SerialId
name: str | None
description: str | None
permissions: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.string_field("name")
self.string_field("description")
self.list_field("permissions", SerialId)

View File

@@ -0,0 +1,101 @@
from cpl.api import APILogger
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.permission import Permissions
from cpl.auth.schema import RoleDao, Role, RolePermissionDao, RolePermission
from cpl.graphql.auth.role.role_input import RoleCreateInput, RoleUpdateInput
from cpl.graphql.schema.mutation import Mutation
class RoleMutation(Mutation):
def __init__(
self,
logger: APILogger,
role_dao: RoleDao,
role_permission_dao: RolePermissionDao,
permission_dao: RolePermissionDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
self._role_dao = role_dao
self._role_permission_dao = role_permission_dao
self._permission_dao = permission_dao
self._keycloak_admin = keycloak_admin
self.int_field(
"create",
self.resolve_create,
).with_require_any_permission(Permissions.roles_create).with_argument(
"input",
RoleCreateInput,
).with_required()
self.bool_field(
"update",
self.resolve_update,
).with_require_any_permission(Permissions.roles_update).with_argument(
"input",
RoleUpdateInput,
).with_required()
self.bool_field(
"delete",
self.resolve_delete,
).with_require_any_permission(Permissions.roles_delete).with_argument(
"id",
int,
).with_required()
self.bool_field(
"restore",
self.resolve_restore,
).with_require_any_permission(Permissions.roles_delete).with_argument(
"id",
int,
).with_required()
async def resolve_create(self, input: RoleCreateInput, *_):
self._logger.debug(f"create role: {input.__dict__}")
role = Role(
0,
input.name,
input.description,
)
await self._role_dao.create(role)
role = await self._role_dao.get_by_name(role.name)
await self._role_permission_dao.create_many([RolePermission(0, role.id, x) for x in input.permissions])
return role
async def resolve_update(self, input: RoleUpdateInput, *_):
self._logger.debug(f"update role: {input.__dict__}")
role = await self._role_dao.get_by_id(input.id)
role.name = input.get("name", role.name)
role.description = input.get("description", role.description)
await self._role_dao.update(role)
await self._resolve_assignments(
input.get("permissions", []),
role,
RolePermission.role_id,
RolePermission.permission_id,
self._role_dao,
self._role_permission_dao,
RolePermission,
self._permission_dao,
)
return role
async def resolve_delete(self, id: int):
self._logger.debug(f"delete role: {id}")
role = await self._role_dao.get_by_id(id)
await self._role_dao.delete(role)
return True
async def resolve_restore(self, id: int):
self._logger.debug(f"restore role: {id}")
role = await self._role_dao.get_by_id(id)
await self._role_dao.restore(role)
return True

View File

@@ -0,0 +1,10 @@
from cpl.auth.schema import Role
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
from cpl.graphql.schema.sort.sort_order import SortOrder
class RoleSort(DbModelSort[Role]):
def __init__(self):
DbModelSort.__init__(self)
self.field("name", SortOrder)
self.field("description", SortOrder)

View File

@@ -0,0 +1,11 @@
from cpl.auth.schema import User
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class UserFilter(DbModelFilter[User]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("username", StringFilter).with_public(public)
self.field("email", StringFilter).with_public(public)

View File

@@ -0,0 +1,12 @@
from cpl.auth.schema import User
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class UserGraphType(DbModelGraphType[User]):
def __init__(self, public: bool = False):
DbModelGraphType.__init__(self)
self.string_field(User.keycloak_id, lambda root: root.keycloak_id).with_public(public)
self.string_field(User.username, lambda root: root.username).with_public(public)
self.string_field(User.email, lambda root: root.email).with_public(public)

View File

@@ -0,0 +1,23 @@
from cpl.auth.schema import User
from cpl.core.typing import SerialId
from cpl.graphql.schema.input import Input
class UserCreateInput(Input[User]):
keycloak_id: str
roles: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.string_field("keycloak_id").with_required()
self.list_field("roles", SerialId)
class UserUpdateInput(Input[User]):
id: SerialId
roles: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.list_field("roles", SerialId)

View File

@@ -0,0 +1,112 @@
from cpl.api import APILogger
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.permission import Permissions
from cpl.auth.schema import UserDao, User, RoleUser, RoleUserDao, RoleDao
from cpl.core.ctx.user_context import get_user
from cpl.graphql.auth.user.user_input import UserCreateInput, UserUpdateInput
from cpl.graphql.schema.mutation import Mutation
class UserMutation(Mutation):
def __init__(
self,
logger: APILogger,
user_dao: UserDao,
role_user_dao: RoleUserDao,
role_dao: RoleDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
self._user_dao = user_dao
self._role_user_dao = role_user_dao
self._role_dao = role_dao
self._keycloak_admin = keycloak_admin
self.int_field(
"create",
self.resolve_create,
).with_require_any_permission(Permissions.users_create).with_argument(
"input",
UserCreateInput,
).with_required()
self.bool_field(
"update",
self.resolve_update,
).with_require_any_permission(Permissions.users_update).with_argument(
"input",
UserUpdateInput,
).with_required()
self.bool_field(
"delete",
self.resolve_delete,
).with_require_any_permission(Permissions.users_delete).with_argument(
"id",
int,
).with_required()
self.bool_field(
"restore",
self.resolve_restore,
).with_require_any_permission(Permissions.users_delete).with_argument(
"id",
int,
).with_required()
async def resolve_create(self, input: UserCreateInput):
self._logger.debug(f"create user: {input.__dict__}")
# ensure keycloak knows a user with this keycloak_id
# get_user should raise an exception if the user does not exist
kc_user = self._keycloak_admin.get_user(input.keycloak_id)
if kc_user is None:
raise ValueError(f"Keycloak user with id {input.keycloak_id} does not exist")
user = User(0, input.keycloak_id, input.license)
user_id = await self._user_dao.create(user)
user = await self._user_dao.get_by_id(user_id)
await self._role_user_dao.create_many([RoleUser(0, user.id, x) for x in set(input.roles)])
return user
async def resolve_update(self, input: UserUpdateInput):
self._logger.debug(f"update user: {input.__dict__}")
user = await self._user_dao.get_by_id(input.id)
if input.license:
user.license = input.license
await self._user_dao.update(user)
await self._resolve_assignments(
input.roles or [],
user,
RoleUser.user_id,
RoleUser.role_id,
self._user_dao,
self._role_user_dao,
RoleUser,
self._role_dao,
)
return user
async def resolve_delete(self, id: int):
self._logger.debug(f"delete user: {id}")
user = await self._user_dao.get_by_id(id)
await self._user_dao.delete(user)
try:
active_user = get_user()
if active_user is not None and active_user.id == user.id:
# await broadcast.publish("userLogout", user.id)
self._keycloak_admin.user_logout(user_id=user.keycloak_id)
except Exception as e:
self._logger.error(f"Failed to logout user from Keycloak", e)
return True
async def resolve_restore(self, id: int):
self._logger.debug(f"restore user: {id}")
user = await self._user_dao.get_by_id(id)
await self._user_dao.restore(user)
return True

View File

@@ -0,0 +1,10 @@
from cpl.auth.schema import User
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
from cpl.graphql.schema.sort.sort_order import SortOrder
class UserSort(DbModelSort[User]):
def __init__(self):
DbModelSort.__init__(self)
self.field("username", SortOrder)
self.field("email", SortOrder)

View File

@@ -0,0 +1,14 @@
from graphql import GraphQLError
from cpl.api import APIError
def graphql_error(api_error: APIError) -> GraphQLError:
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
return GraphQLError(
message=api_error.error_message,
extensions={
"code": api_error.status_code,
},
original_error=api_error,
)

View File

@@ -0,0 +1,27 @@
import asyncio
from typing import Any, AsyncGenerator
from cpl.dependency.event_bus import EventBusABC
class InMemoryEventBus(EventBusABC):
def __init__(self):
self._subscribers: dict[str, list[asyncio.Queue]] = {}
async def publish(self, channel: str, event: Any) -> None:
queues = self._subscribers.get(channel, [])
for q in queues.copy():
await q.put(event)
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]:
q = asyncio.Queue()
if channel not in self._subscribers:
self._subscribers[channel] = []
self._subscribers[channel].append(q)
try:
while True:
item = await q.get()
yield item
finally:
self._subscribers[channel].remove(q)

View File

@@ -0,0 +1,25 @@
from cpl.api.api_module import ApiModule
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
from cpl.graphql.service.graphql import GraphQLService
from cpl.graphql.service.schema import Schema
class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [Schema, RootQuery, RootMutation, RootSubscription]
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
@staticmethod
def configure(services: ServiceProvider) -> None:
schema = services.get_service(Schema)
schema.build()

View File

@@ -0,0 +1,48 @@
from enum import Enum
from typing import Optional
from graphql import GraphQLResolveInfo
from cpl.auth.schema import User, Permission
from cpl.core.ctx import get_user
class QueryContext:
def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs):
self._user = get_user()
self._user_permissions = user_permissions or []
self._resolve_info = None
for arg in args:
if isinstance(arg, GraphQLResolveInfo):
self._resolve_info = arg
continue
self._args = args
self._kwargs = kwargs
self._is_mutation = is_mutation
@property
def user(self) -> User:
return self._user
@property
def resolve_info(self) -> Optional[GraphQLResolveInfo]:
return self._resolve_info
@property
def args(self) -> tuple:
return self._args
@property
def kwargs(self) -> dict:
return self._kwargs
@property
def is_mutation(self) -> bool:
return self._is_mutation
def has_permission(self, permission: Enum | str) -> bool:
return permission.value if isinstance(permission, Enum) else permission in self._user_permissions

View File

@@ -0,0 +1,54 @@
from typing import Any, Self
class Argument:
def __init__(
self,
name: str,
t: type,
description: str = None,
default: Any = None,
optional: bool = None,
):
self._name = name
self._type = t
self._description = description
self._default = default
self._optional = optional
@property
def name(self) -> str:
return self._name
@property
def type(self) -> type:
return self._type
@property
def description(self) -> str | None:
return self._description
@property
def default(self) -> Any | None:
return self._default
@property
def optional(self) -> bool | None:
return self._optional
def with_description(self, description: str) -> Self:
self._description = description
return self
def with_default(self, default: Any) -> Self:
self._default = default
return self
def with_optional(self, optional: bool) -> Self:
self._optional = optional
return self
def with_required(self, required: bool = True) -> Self:
self._optional = not required
return self

View File

@@ -0,0 +1,61 @@
from typing import Type, Dict, List
import strawberry
from cpl.core.typing import T
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.utils.type_collector import TypeCollector
class CollectionGraphTypeFactory:
@classmethod
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
type_name = f"{node_type.__name__.replace('GraphType', '')}Collection"
if TypeCollector.has(type_name):
return TypeCollector.get(type_name)
node_t = get_provider().get_service(node_type)
if not node_t:
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_cls = type(type_name, (), {})
TypeCollector.set(type_name, gql_cls)
gql_cls.__annotations__ = {
"nodes": List[gql_node],
"total_count": int,
"count": int,
}
for k in gql_cls.__annotations__.keys():
setattr(gql_cls, k, strawberry.field())
gql_type = strawberry.type(gql_cls)
TypeCollector.set(type_name, gql_type)
return gql_type
class Collection:
def __init__(self, nodes: list[T], total_count: int, count: int):
self._nodes = nodes
self._total_count = total_count
self._count = count
@property
def nodes(self) -> list[T]:
return self._nodes
@property
def total_count(self) -> int:
return self._total_count
@property
def count(self) -> int:
return self._count

View File

@@ -0,0 +1,62 @@
from typing import Type, Optional, Generic, Annotated
import strawberry
from cpl.core.configuration import Configuration
from cpl.core.typing import T
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
class DbModelGraphType(GraphType[T], Generic[T]):
def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False):
Query.__init__(self)
self._dao: Optional[DataAccessObjectABC] = None
if t_dao is not None:
dao = self._provider.get_service(t_dao)
if dao is not None:
self._dao = dao
self.int_field("id", lambda root: root.id).with_public(public)
self.bool_field("deleted", lambda root: root.deleted).with_public(public)
if Configuration.get("GraphQLAuthModuleEnabled", False):
from cpl.graphql.auth.user.user_graph_type import UserGraphType
self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public)
self.string_field("created", lambda root: root.created).with_public(public)
self.string_field("updated", lambda root: root.updated).with_public(public)
# if with_history:
# if self._dao is None:
# raise ValueError("DAO must be provided to enable history")
# self.set_field("history", self._resolve_history).with_public(public)
self._history_reference_daos: dict[DataAccessObjectABC, str] = {}
async def _resolve_history(self, root):
if self._dao is None:
raise Exception("DAO not set for history query")
history = sorted(
[await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)],
key=lambda h: h.updated,
reverse=True,
)
return history
def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None):
"""
Set the reference DAO for history resolution.
:param dao:
:param key: The key to use for resolving history.
:return:
"""
if key is None:
key = "id"
self._history_reference_daos[dao] = key

View File

@@ -0,0 +1,141 @@
from enum import Enum
from typing import Self
from cpl.graphql.schema.argument import Argument
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
class Field:
def __init__(
self,
name: str,
t: type = None,
resolver: Resolver = None,
optional=None,
default=None,
subquery: TQuery = None,
parent_type=None,
):
self._name = name
self._type = t
self._resolver = resolver
self._optional = optional or True
self._default = default
self._subquery = subquery
self._parent_type = parent_type
self._args: dict[str, Argument] = {}
self._require_any_permission = None
self._require_any = None
self._public = False
@property
def name(self) -> str:
return self._name
@property
def type(self) -> type:
return self._type
@property
def resolver(self) -> callable:
return self._resolver
@property
def optional(self) -> bool | None:
return self._optional
@property
def default(self):
return self._default
@property
def args(self) -> dict:
return self._args
@property
def subquery(self) -> TQuery | None:
return self._subquery
@property
def parent_type(self):
return self._parent_type
@property
def arguments(self) -> dict[str, Argument]:
return self._args
@property
def require_any_permission(self) -> TRequireAnyPermissions | None:
return self._require_any_permission
@property
def require_any(self) -> TRequireAnyResolvers | None:
return self._require_any
@property
def public(self) -> bool:
return self._public
def with_type(self, t: type) -> Self:
self._type = t
return self
def with_resolver(self, resolver: Resolver) -> Self:
self._resolver = resolver
return self
def with_optional(self, optional: bool = True) -> Self:
self._optional = optional
return self
def with_required(self, required: bool = True) -> Self:
self._optional = not required
return self
def with_default(self, default) -> Self:
self._default = default
return self
def with_argument(
self, name: str, arg_type: type, description: str = None, default_value=None, optional=True
) -> Argument:
if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(name, arg_type, description, default_value, optional)
return self._args[name]
def with_arguments(self, args: list[Argument]) -> Self:
for arg in args:
if not isinstance(arg, Argument):
raise ValueError(f"Expected Argument instance, got {type(arg)}")
self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
return self
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
if not isinstance(permissions, list):
permissions = list(permissions)
assert permissions is not None, "require_any_permission cannot be None"
assert all(isinstance(x, (str, Enum)) for x in permissions), "All permissions must be of Permission type"
self._require_any_permission = permissions
return self
def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self:
assert permissions is not None, "permissions cannot be None"
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
assert resolvers is not None, "resolvers cannot be None"
assert all(callable(r) for r in resolvers), "All resolvers must be callable"
self._require_any = (permissions, resolvers)
return self
def with_public(self, public: bool = True) -> Self:
if public:
self._require_any = None
self._require_any_permission = None
self._public = public
return self

View File

@@ -0,0 +1,10 @@
from cpl.graphql.schema.input import Input
class BoolFilter(Input[bool]):
def __init__(self):
super().__init__()
self.field("equal", bool, optional=True)
self.field("notEqual", bool, optional=True)
self.field("isNull", bool, optional=True)
self.field("isNotNull", bool, optional=True)

View File

@@ -0,0 +1,18 @@
from datetime import datetime
from cpl.graphql.schema.input import Input
class DateFilter(Input[datetime]):
def __init__(self):
super().__init__()
self.field("equal", datetime, optional=True)
self.field("notEqual", datetime, optional=True)
self.field("greater", datetime, optional=True)
self.field("greaterOrEqual", datetime, optional=True)
self.field("less", datetime, optional=True)
self.field("lessOrEqual", datetime, optional=True)
self.field("isNull", datetime, optional=True)
self.field("isNotNull", datetime, optional=True)
self.field("in", list[datetime], optional=True)
self.field("notIn", list[datetime], optional=True)

View File

@@ -0,0 +1,23 @@
from typing import Generic
from cpl.core.configuration.configuration import Configuration
from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
class DbModelFilter(Filter[T], Generic[T]):
def __init__(self, public: bool = False):
Filter.__init__(self)
self.field("id", IntFilter).with_public(public)
self.field("deleted", BoolFilter).with_public(public)
if Configuration.get("GraphQLAuthModuleEnabled", False):
from cpl.graphql.auth.user.user_filter import UserFilter
self.field("editor", lambda: UserFilter).with_public(public)
self.field("created", DateFilter).with_public(public)
self.field("updated", DateFilter).with_public(public)

View File

@@ -0,0 +1,28 @@
from typing import Type
from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.input import Input
class Filter(Input[T]):
def __init__(self):
Input.__init__(self)
def filter_field(self, name: str, filter_type: Type["Filter"]):
self.field(name, filter_type)
def string_field(self, name: str):
self.field(name, StringFilter)
def int_field(self, name: str):
self.field(name, IntFilter)
def bool_field(self, name: str):
self.field(name, BoolFilter)
def date_field(self, name: str):
self.field(name, DateFilter)

View File

@@ -0,0 +1,16 @@
from cpl.graphql.schema.input import Input
class IntFilter(Input[int]):
def __init__(self):
super().__init__()
self.field("equal", int, optional=True)
self.field("notEqual", int, optional=True)
self.field("greater", int, optional=True)
self.field("greaterOrEqual", int, optional=True)
self.field("less", int, optional=True)
self.field("lessOrEqual", int, optional=True)
self.field("isNull", int, optional=True)
self.field("isNotNull", int, optional=True)
self.field("in", list[int], optional=True)
self.field("notIn", list[int], optional=True)

View File

@@ -0,0 +1,16 @@
from cpl.graphql.schema.input import Input
class StringFilter(Input[str]):
def __init__(self):
super().__init__()
self.field("equal", str, optional=True)
self.field("notEqual", str, optional=True)
self.field("contains", str, optional=True)
self.field("notContains", str, optional=True)
self.field("startsWith", str, optional=True)
self.field("endsWith", str, optional=True)
self.field("isNull", str, optional=True)
self.field("isNotNull", str, optional=True)
self.field("in", list[str], optional=True)
self.field("notIn", list[str], optional=True)

View File

@@ -0,0 +1,10 @@
from typing import Generic
from cpl.core.typing import T
from cpl.graphql.schema.query import Query
class GraphType(Query, Generic[T]):
def __init__(self):
Query.__init__(self)

View File

@@ -0,0 +1,115 @@
import types
from typing import Generic, Dict, Type, Optional, Union, Any
import strawberry
from cpl.core.typing import T
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import AttributeName
from cpl.graphql.utils.type_collector import TypeCollector
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
class Input(StrawberryProtocol, Generic[T]):
def __init__(self):
self._fields: Dict[str, Field] = {}
self._values: Dict[str, Any] = {}
@property
def fields(self) -> Dict[str, Field]:
return self._fields
def __getattr__(self, item):
if item in self._values:
return self._values[item]
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
def __setattr__(self, key, value):
if key in {"_fields", "_values"}:
super().__setattr__(key, value)
elif key in self._fields:
self._values[key] = value
else:
super().__setattr__(key, value)
def get(self, key: str, default=None):
return self._values.get(key, default)
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field:
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, typ, optional=optional)
return self._fields[name]
def string_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, str)
def int_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, int, optional)
def float_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, float, optional)
def bool_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, bool, optional)
def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field:
return self.field(name, list[t], optional)
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
if not isinstance(t, type) and callable(t):
return self.field(name, t, optional)
return self.field(name, t().to_strawberry(), optional)
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
# register early to handle recursive types
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(t, types.FunctionType):
_t = get_provider().get_service(t())
if _t is None:
raise ValueError(f"'{t()}' could not be resolved from the provider")
t = _t.to_strawberry()
elif isinstance(t, type) and issubclass(t, Input):
t = t().to_strawberry()
elif isinstance(t, Input):
t = t.to_strawberry()
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
annotations[py_name] = t if not f.optional else Optional[t]
field_args = {}
if py_name != name:
field_args["name"] = name
default = None if f.optional else f.default
namespace[py_name] = strawberry.field(default=default, **field_args)
namespace["__annotations__"] = annotations
for k, v in namespace.items():
setattr(gql_cls, k, v)
gql_cls.__annotations__ = annotations
gql_type = strawberry.input(gql_cls)
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,93 @@
from typing import Type, Union
from cpl.core.typing import T
from cpl.database.abc import DataAccessObjectABC, DbJoinModelABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.schema.field import Field
class Mutation(QueryABC):
@inject
def __init__(self, provider: ServiceProvider):
QueryABC.__init__(self)
self._provider = provider
from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema)
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
sub = self._provider.get_service(cls)
if not sub:
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
return self.field(name, sub.to_strawberry(), lambda: sub)
@staticmethod
async def _resolve_assignments(
foreign_objs: list[int],
resolved_obj: T,
reference_key_own: Union[str, property],
reference_key_foreign: Union[str, property],
source_dao: DataAccessObjectABC[T],
join_dao: DataAccessObjectABC[T],
join_type: Type[DbJoinModelABC],
foreign_dao: DataAccessObjectABC[T],
):
if foreign_objs is None:
return
reference_key_foreign_attr = reference_key_foreign
if isinstance(reference_key_foreign, property):
reference_key_foreign_attr = reference_key_foreign.fget.__name__
foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}])
to_delete = (
foreign_list
if len(foreign_objs) == 0
else await join_dao.find_by(
[
{reference_key_own: resolved_obj.id},
{reference_key_foreign: {"notIn": foreign_objs}},
]
)
)
foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list]
deleted_foreign_ids = [
getattr(x, reference_key_foreign_attr)
for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}])
]
to_create = [
join_type(0, resolved_obj.id, x)
for x in foreign_objs
if x not in foreign_ids and x not in deleted_foreign_ids
]
to_restore = [
await join_dao.get_single_by(
[
{reference_key_own: resolved_obj.id},
{reference_key_foreign: x},
]
)
for x in foreign_objs
if x not in foreign_ids and x in deleted_foreign_ids
]
if len(to_delete) > 0:
await join_dao.delete_many(to_delete)
if len(to_create) > 0:
await join_dao.create_many(to_create)
if len(to_restore) > 0:
await join_dao.restore_many(to_restore)
foreign_changes = [*to_delete, *to_create, *to_restore]
if len(foreign_changes) > 0:
await source_dao.touch(resolved_obj)
await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes])

View File

@@ -0,0 +1,131 @@
from typing import Callable, Type
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder
class Query(QueryABC):
@inject
def __init__(self, provider: ServiceProvider):
QueryABC.__init__(self)
self._provider = provider
from cpl.graphql.service.schema import Schema
self._schema = provider.get_service(Schema)
def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
sub = self._provider.get_service(subquery_cls)
if not sub:
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
return self.field(name, sub.to_strawberry(), lambda: sub)
def collection_field(
self,
t: type,
name: str,
filter_type: Type[StrawberryProtocol],
sort_type: Type[StrawberryProtocol],
resolver: Callable,
) -> Field:
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
items = resolver()
if filter:
for field, value in filter.__dict__.items():
if value is None:
continue
items = [i for i in items if getattr(i, field) == value]
if sort:
for field, direction in sort.__dict__.items():
reverse = direction == SortOrder.DESC
items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse)
total_count = len(items)
paged = items[skip : skip + take]
return Collection(nodes=paged, total_count=total_count, count=len(paged))
filter = self._provider.get_service(filter_type)
if not filter:
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
sort = self._provider.get_service(sort_type)
if not sort:
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
f.with_argument("filter", filter.to_strawberry())
f.with_argument("sort", sort.to_strawberry())
f.with_argument("skip", int, default_value=0)
f.with_argument("take", int, default_value=10)
return f
def dao_collection_field(
self,
t: Type[StrawberryProtocol],
dao_type: Type[DataAccessObjectABC],
name: str,
filter_type: Type[StrawberryProtocol],
sort_type: Type[StrawberryProtocol],
) -> Field:
assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC"
dao = self._provider.get_service(dao_type)
if not dao:
raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider")
filter = self._provider.get_service(filter_type)
if not filter:
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
sort = self._provider.get_service(sort_type)
if not sort:
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
def input_to_dict(obj) -> dict | None:
if obj is None:
return None
result = {}
for k, v in obj.__dict__.items():
if v is None:
continue
if hasattr(v, "__dict__"):
result[k] = input_to_dict(v)
else:
result[k] = v
return result
async def _resolver(filter=None, sort=None, take=10, skip=0):
filter_dict = input_to_dict(filter) if filter is not None else None
sort_dict = None
if sort is not None:
sort_dict = {}
for k, v in sort.__dict__.items():
if v is None:
continue
if isinstance(v, SortOrder):
sort_dict[k] = str(v.value).lower()
continue
sort_dict[k] = str(v).lower()
total_count = await dao.count(filter_dict)
data = await dao.find_by(filter_dict, sort_dict, take, skip)
return Collection(nodes=data, total_count=total_count, count=len(data))
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
f.with_argument("filter", filter.to_strawberry())
f.with_argument("sort", sort.to_strawberry())
f.with_argument("skip", int, default_value=0)
f.with_argument("take", int, default_value=10)
return f

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.mutation import Mutation
class RootMutation(Mutation):
def __init__(self):
Mutation.__init__(self)

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.query import Query
class RootQuery(Query):
def __init__(self):
Query.__init__(self)

View File

@@ -0,0 +1,6 @@
from cpl.graphql.schema.subscription import Subscription
class RootSubscription(Subscription):
def __init__(self):
Subscription.__init__(self)

View File

@@ -0,0 +1,19 @@
from typing import Generic
from cpl.core.configuration import Configuration
from cpl.core.typing import T
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
class DbModelSort(Sort[T], Generic[T]):
def __init__(
self,
):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("deleted", SortOrder)
if Configuration.get("GraphQLAuthModuleEnabled", False):
self.field("editor", SortOrder)
self.field("created", SortOrder)
self.field("updated", SortOrder)

View File

@@ -0,0 +1,9 @@
from cpl.core.typing import T
from cpl.graphql.schema.input import Input
class Sort(Input[T]):
def __init__(
self,
):
Input.__init__(self)

View File

@@ -0,0 +1,6 @@
from enum import Enum, auto
class SortOrder(Enum):
ASC = "ASC"
DESC = "DESC"

View File

@@ -0,0 +1,88 @@
import inspect
from typing import Any, Type, Optional, Self
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider, inject
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.abc.query_abc import QueryABC
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.subscription_field import SubscriptionField
from cpl.graphql.typing import Selector
class Subscription(QueryABC):
@inject
def __init__(self, bus: EventBusABC):
QueryABC.__init__(self)
self._bus = bus
def subscription_field(
self,
name: str,
t: Type,
selector: Optional[Selector] = None,
channel: Optional[str] = None,
) -> SubscriptionField:
field = SubscriptionField(name, t, selector, channel)
self._fields[name] = field
return field
def with_subscription(self, sub_cls: Type[Self]) -> Self:
sub = get_provider().get_service(sub_cls)
if not sub:
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider")
for sub_name, sub_field in sub.get_fields().items():
self._fields[sub_name] = sub_field
return self
def _field_to_strawberry(self, f: SubscriptionField) -> Any:
try:
if isinstance(f, SubscriptionField):
def make_resolver(field: SubscriptionField):
async def resolver(root=None, info=None):
if not field.public:
user = get_user()
if not user:
raise graphql_error(Unauthorized(f"{field.name}: Authentication required"))
if field.require_any_permission:
ok = any([await user.has_permission(p) for p in field.require_any_permission])
if not ok:
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
if field.require_any:
perms, resolvers = field.require_any
ok = any([await user.has_permission(p) for p in perms])
if not ok:
ctx = QueryContext([x.name for x in await user.permissions])
results = [
r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx)
for r in resolvers
]
if not any(results):
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
async for event in self._bus.subscribe(field.channel):
if field.selector is None or field.selector(event, info):
yield event
return resolver
return strawberry.subscription(resolver=make_resolver(f))
async def wrapper_resolver(root=None, info=None):
yield None
return strawberry.subscription(resolver=wrapper_resolver)
except StrawberryException as e:
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e

View File

@@ -0,0 +1,25 @@
from typing import Type, Callable, Optional
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Selector
class SubscriptionField(Field):
def __init__(
self,
name: str,
t: Type,
selector: Optional[Selector] = None,
channel: Optional[str] = None,
):
super().__init__(name, t)
self.selector = selector
self.channel = channel or name
def with_selector(self, selector: Selector) -> "SubscriptionField":
self.selector = selector
return self
def with_channel(self, channel: str) -> "SubscriptionField":
self.channel = channel
return self

View File

@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional
from graphql import GraphQLError
from cpl.api import APILogger, APIError
from cpl.api.typing import TRequest
from cpl.graphql.service.schema import Schema
class GraphQLService:
def __init__(self, logger: APILogger, schema: Schema):
self._logger = logger
if schema.schema is None:
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
self._schema = schema.schema
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute(
query,
variable_values=variables,
context_value={"request": request},
)
response_data: Dict[str, Any] = {}
if result.errors:
errors = []
for error in result.errors:
if isinstance(error, APIError):
self._logger.error(f"GraphQL APIError", error)
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
continue
if isinstance(error, GraphQLError):
self._logger.error(f"GraphQLError", error)
errors.append({"message": error.message, "extensions": error.extensions})
continue
self._logger.error(f"GraphQL unexpected error", error)
errors.append({"message": str(error), "extensions": {"code": 500}})
response_data["errors"] = errors
if result.data:
response_data["data"] = result.data
return response_data

View File

@@ -0,0 +1,76 @@
import logging
from typing import Type, Self
import strawberry
from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
class Schema:
def __init__(self, logger: APILogger, provider: ServiceProvider):
self._logger = logger
self._provider = provider
self._types: dict[str, Type[StrawberryProtocol]] = {}
self._schema = None
@property
def schema(self) -> strawberry.Schema | None:
return self._schema
@property
def query(self) -> RootQuery:
query = self._provider.get_service(RootQuery)
if not query:
raise ValueError("RootQuery not registered in service provider")
return query
@property
def mutation(self) -> RootMutation:
mutation = self._provider.get_service(RootMutation)
if not mutation:
raise ValueError("RootMutation not registered in service provider")
return mutation
@property
def subscription(self) -> RootSubscription:
subscription = self._provider.get_service(RootSubscription)
if not subscription:
raise ValueError("RootSubscription not registered in service provider")
return subscription
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
self._types[t.__name__] = t
return self
def _get_types(self):
types: list[Type] = []
for t in self._types.values():
t_obj = self._provider.get_service(t)
if not t_obj:
raise ValueError(f"Type '{t.__name__}' not registered in service provider")
types.append(t_obj.to_strawberry())
return types
def build(self) -> strawberry.Schema:
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
query = self.query
mutation = self.mutation
subscription = self.subscription
self._schema = strawberry.Schema(
query=query.to_strawberry() if query.fields_count > 0 else None,
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None,
types=self._get_types(),
)
return self._schema

View File

@@ -0,0 +1,16 @@
from enum import Enum
from typing import Type, Callable, List, Tuple, Awaitable, Any
import strawberry
from cpl.auth.permission import Permissions
from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"]
Resolver = Callable
Selector = Callable[[Any, strawberry.types.Info], bool]
ScalarType = str | int | float | bool | object
AttributeName = str | property
TRequireAnyPermissions = List[Enum | Permissions] | None
TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],]
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]

View File

@@ -0,0 +1,28 @@
from cpl.core.pipes import PipeABC
from cpl.core.typing import T
from cpl.graphql.schema.collection import CollectionGraphType
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.object_graph_type import ObjectGraphType
class NamePipe(PipeABC):
@staticmethod
def to_str(value: type, *args) -> str:
if isinstance(value, str):
return value
if not isinstance(value, type):
raise ValueError(f"Expected a type, got {type(value)}")
if issubclass(value, CollectionGraphType):
return f"{value.__name__.replace(GraphType.__name__, "")}"
if issubclass(value, (ObjectGraphType, GraphType)):
return value.__name__.replace(GraphType.__name__, "")
return value.__name__
@staticmethod
def from_str(value: str, *args) -> T:
pass

View File

@@ -0,0 +1,17 @@
from typing import Type, Any
class TypeCollector:
_registry: dict[type | str, Type] = {}
@classmethod
def has(cls, base: type | str) -> bool:
return base in cls._registry
@classmethod
def get(cls, base: type | str) -> Type:
return cls._registry[base]
@classmethod
def set(cls, base: type | str, gql_type: Type):
cls._registry[base] = gql_type

View File

@@ -0,0 +1,30 @@
[build-system]
requires = ["setuptools>=70.1.0", "wheel>=0.43.0"]
build-backend = "setuptools.build_meta"
[project]
name = "cpl-database"
version = "2024.7.0"
description = "CPL database"
readme ="CPL database package"
requires-python = ">=3.12"
license = { text = "MIT" }
authors = [
{ name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" }
]
keywords = ["cpl", "database", "backend", "shared", "library"]
dynamic = ["dependencies", "optional-dependencies"]
[project.urls]
Homepage = "https://www.sh-edraft.de"
[tool.setuptools.packages.find]
where = ["."]
include = ["cpl*"]
[tool.setuptools.dynamic]
dependencies = { file = ["requirements.txt"] }
optional-dependencies.dev = { file = ["requirements.dev.txt"] }

View File

@@ -0,0 +1 @@
black==25.1.0

View File

@@ -0,0 +1,2 @@
cpl-api
strawberry-graphql==0.282.0