Renamed project dirs
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 6s
This commit is contained in:
1
src/graphql/cpl/graphql/__init__.py
Normal file
1
src/graphql/cpl/graphql/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "1.0.0"
|
||||
0
src/graphql/cpl/graphql/_endpoints/__init__.py
Normal file
0
src/graphql/cpl/graphql/_endpoints/__init__.py
Normal file
69
src/graphql/cpl/graphql/_endpoints/graphiql.py
Normal file
69
src/graphql/cpl/graphql/_endpoints/graphiql.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
|
||||
async def graphiql_endpoint(request):
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>GraphiQL</title>
|
||||
<link
|
||||
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
</head>
|
||||
<body style="margin:0;overflow:hidden;">
|
||||
<div id="graphiql" style="height:100vh;"></div>
|
||||
|
||||
<!-- React + ReactDOM -->
|
||||
<script src="https://unpkg.com/react@18.2.0/umd/react.production.min.js"></script>
|
||||
<script src="https://unpkg.com/react-dom@18.2.0/umd/react-dom.production.min.js"></script>
|
||||
|
||||
<!-- GraphiQL -->
|
||||
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
|
||||
|
||||
<!-- GraphQL over WebSocket client -->
|
||||
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
|
||||
|
||||
<script>
|
||||
const httpUrl = window.location.origin + '/api/graphql';
|
||||
const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
|
||||
window.location.host + '/api/graphql/ws';
|
||||
|
||||
// HTTP fetcher for queries & mutations
|
||||
const httpFetcher = async (params) => {
|
||||
const res = await fetch(httpUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(params),
|
||||
});
|
||||
return res.json();
|
||||
};
|
||||
|
||||
// WebSocket fetcher for subscriptions
|
||||
const wsClient = graphqlWs.createClient({ url: wsUrl });
|
||||
const wsFetcher = (params) => ({
|
||||
subscribe: (sink) => ({
|
||||
unsubscribe: wsClient.subscribe(params, sink),
|
||||
}),
|
||||
});
|
||||
|
||||
// smart fetcher wrapper (decides HTTP or WS)
|
||||
const graphQLFetcher = (params) => {
|
||||
if (params.query.trim().startsWith('subscription')) {
|
||||
return wsFetcher(params);
|
||||
}
|
||||
return httpFetcher(params);
|
||||
};
|
||||
|
||||
ReactDOM.render(
|
||||
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),
|
||||
document.getElementById('graphiql'),
|
||||
);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
13
src/graphql/cpl/graphql/_endpoints/graphql.py
Normal file
13
src/graphql/cpl/graphql/_endpoints/graphql.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
|
||||
|
||||
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:
|
||||
body = await request.json()
|
||||
query = body.get("query")
|
||||
variables = body.get("variables")
|
||||
|
||||
response_data = await service.execute(query, variables, request)
|
||||
return JSONResponse(response_data)
|
||||
27
src/graphql/cpl/graphql/_endpoints/lazy_graphql_app.py
Normal file
27
src/graphql/cpl/graphql/_endpoints/lazy_graphql_app.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from strawberry.asgi import GraphQL
|
||||
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
|
||||
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class LazyGraphQLApp:
|
||||
|
||||
def __init__(self, services: ServiceProvider):
|
||||
self._services = services
|
||||
self._graphql_app = None
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if self._graphql_app is None:
|
||||
schema = self._services.get_service(Schema)
|
||||
if not schema or not schema.schema:
|
||||
raise RuntimeError("GraphQL Schema not available yet")
|
||||
|
||||
self._graphql_app = GraphQL(
|
||||
schema.schema,
|
||||
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
||||
)
|
||||
|
||||
await self._graphql_app(scope, receive, send)
|
||||
29
src/graphql/cpl/graphql/_endpoints/playground.py
Normal file
29
src/graphql/cpl/graphql/_endpoints/playground.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, HTMLResponse
|
||||
|
||||
|
||||
async def playground_endpoint(request: Request) -> Response:
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset=utf-8/>
|
||||
<title>GraphQL Playground</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/graphql-playground-react/build/static/css/index.css" />
|
||||
<link rel="shortcut icon" href="https://raw.githubusercontent.com/graphql/graphql-playground/master/packages/graphql-playground-react/public/favicon.png" />
|
||||
<script src="https://unpkg.com/graphql-playground-react/build/static/js/middleware.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"/>
|
||||
<script>
|
||||
window.addEventListener('load', function () {
|
||||
GraphQLPlayground.init(document.getElementById('root'), {
|
||||
endpoint: '/api/graphql'
|
||||
})
|
||||
})
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
0
src/graphql/cpl/graphql/abc/__init__.py
Normal file
0
src/graphql/cpl/graphql/abc/__init__.py
Normal file
227
src/graphql/cpl/graphql/abc/query_abc.py
Normal file
227
src/graphql/cpl/graphql/abc/query_abc.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
from abc import ABC
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Callable, Type, Any, Optional
|
||||
|
||||
import strawberry
|
||||
from async_property.base import AsyncPropertyDescriptor
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import Resolver, AttributeName
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class QueryABC(StrawberryProtocol, ABC):
|
||||
|
||||
def __init__(self):
|
||||
ABC.__init__(self)
|
||||
self._fields: dict[str, Field] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
@property
|
||||
def fields_count(self) -> int:
|
||||
return len(self._fields)
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(
|
||||
self,
|
||||
name: AttributeName,
|
||||
t: type,
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
if isinstance(name, property):
|
||||
name = name.fget.__name__
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
if not isinstance(t, type) and callable(t):
|
||||
return self.field(name, t, resolver)
|
||||
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
for arg in f.arguments.values():
|
||||
_type = arg.type
|
||||
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
|
||||
_type = _type().to_strawberry()
|
||||
|
||||
ann = Optional[_type] if arg.optional else _type
|
||||
|
||||
if arg.default is None:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
)
|
||||
else:
|
||||
param = inspect.Parameter(
|
||||
arg.name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=ann,
|
||||
default=arg.default,
|
||||
)
|
||||
|
||||
params.append(param)
|
||||
|
||||
sig = inspect.Signature(parameters=params, return_annotation=f.type)
|
||||
|
||||
async def _resolver(*args, **kwargs):
|
||||
if f.resolver is None:
|
||||
return None
|
||||
|
||||
if iscoroutinefunction(f.resolver):
|
||||
return await f.resolver(*args, **kwargs)
|
||||
return f.resolver(*args, **kwargs)
|
||||
|
||||
_resolver.__signature__ = sig
|
||||
return _resolver
|
||||
|
||||
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
|
||||
sig = getattr(resolver, "__signature__", None)
|
||||
|
||||
@functools.wraps(resolver)
|
||||
async def _auth_resolver(*args, **kwargs):
|
||||
if f.public:
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
user = get_user()
|
||||
|
||||
if user is None:
|
||||
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
|
||||
|
||||
if f.require_any_permission:
|
||||
if not any([await user.has_permission(p) for p in f.require_any_permission]):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
if f.require_any:
|
||||
perms, resolvers = f.require_any
|
||||
if not any([await user.has_permission(p) for p in perms]):
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
|
||||
|
||||
if not any(resolved):
|
||||
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
|
||||
|
||||
return await self._run_resolver(resolver, *args, **kwargs)
|
||||
|
||||
if sig:
|
||||
_auth_resolver.__signature__ = sig
|
||||
|
||||
return _auth_resolver
|
||||
|
||||
@staticmethod
|
||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||
result = r(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
try:
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda root: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
ann = dict(ann)
|
||||
ann["return"] = f.type
|
||||
f.resolver.__annotations__ = ann
|
||||
resolver = f.resolver
|
||||
|
||||
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
|
||||
except StrawberryException as e:
|
||||
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _type_to_strawberry(t: Type) -> Type:
|
||||
_t = get_provider().get_service(t)
|
||||
|
||||
if isinstance(_t, StrawberryProtocol):
|
||||
return _t.to_strawberry()
|
||||
|
||||
return t
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
|
||||
# register early to handle recursive types
|
||||
TypeCollector.set(cls, gql_cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
t = f.type
|
||||
if isinstance(name, property):
|
||||
name = name.fget.__name__
|
||||
if isinstance(name, AsyncPropertyDescriptor):
|
||||
name = name.field_name
|
||||
|
||||
if isinstance(t, types.GenericAlias):
|
||||
t = t.__args__[0]
|
||||
|
||||
if callable(t) and not isinstance(t, type):
|
||||
t = self._type_to_strawberry(t())
|
||||
elif issubclass(t, StrawberryProtocol):
|
||||
t = self._type_to_strawberry(t)
|
||||
|
||||
annotations[name] = t if not f.optional else Optional[t]
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
for k, v in namespace.items():
|
||||
if isinstance(k, property):
|
||||
k = k.fget.__name__
|
||||
if isinstance(k, AsyncPropertyDescriptor):
|
||||
k = k.field_name
|
||||
|
||||
setattr(gql_cls, k, v)
|
||||
|
||||
try:
|
||||
gql_cls.__annotations__ = annotations
|
||||
gql_type = strawberry.type(gql_cls)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
11
src/graphql/cpl/graphql/abc/strawberry_protocol.py
Normal file
11
src/graphql/cpl/graphql/abc/strawberry_protocol.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Protocol, Type, runtime_checkable
|
||||
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.subscription_field import SubscriptionField
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class StrawberryProtocol(Protocol):
|
||||
def to_strawberry(self) -> Type: ...
|
||||
|
||||
def get_fields(self) -> dict[str, Field | SubscriptionField]: ...
|
||||
1
src/graphql/cpl/graphql/application/__init__.py
Normal file
1
src/graphql/cpl/graphql/application/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .graphql_app import WebApp
|
||||
126
src/graphql/cpl/graphql/application/graphql_app.py
Normal file
126
src/graphql/cpl/graphql/application/graphql_app.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import socket
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from cpl.api.application import WebApp
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.application.abc.application_abc import __not_implemented__
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
|
||||
from cpl.graphql._endpoints.graphql import graphql_endpoint
|
||||
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
|
||||
from cpl.graphql._endpoints.playground import playground_endpoint
|
||||
from cpl.graphql.graphql_module import GraphQLModule
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLApp(WebApp):
|
||||
def __init__(self, services: ServiceProvider, modules: Modules):
|
||||
WebApp.__init__(self, services, modules, [GraphQLModule])
|
||||
|
||||
self._with_graphiql = False
|
||||
self._with_playground = False
|
||||
|
||||
def with_graphql(
|
||||
self,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Schema:
|
||||
self.with_route(
|
||||
path="/api/graphql",
|
||||
fn=graphql_endpoint,
|
||||
method="POST",
|
||||
authentication=authentication,
|
||||
roles=roles,
|
||||
permissions=permissions,
|
||||
policies=policies,
|
||||
match=match,
|
||||
)
|
||||
schema = self._services.get_service(Schema)
|
||||
if schema is None:
|
||||
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
|
||||
#
|
||||
# graphql_ws_app = GraphQL(
|
||||
# schema,
|
||||
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
||||
# )
|
||||
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
|
||||
return schema
|
||||
|
||||
def with_graphiql(
|
||||
self,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self:
|
||||
self.with_route(
|
||||
path="/api/graphiql",
|
||||
fn=graphiql_endpoint,
|
||||
method="GET",
|
||||
authentication=authentication,
|
||||
roles=roles,
|
||||
permissions=permissions,
|
||||
policies=policies,
|
||||
match=match,
|
||||
)
|
||||
self._with_graphiql = True
|
||||
return self
|
||||
|
||||
def with_playground(
|
||||
self,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self:
|
||||
self.with_route(
|
||||
path="/api/playground",
|
||||
fn=playground_endpoint,
|
||||
method="GET",
|
||||
authentication=authentication,
|
||||
roles=roles,
|
||||
permissions=permissions,
|
||||
policies=policies,
|
||||
match=match,
|
||||
)
|
||||
self._with_playground = True
|
||||
return self
|
||||
|
||||
def with_auth_root_queries(self, public: bool = False):
|
||||
try:
|
||||
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
|
||||
|
||||
GraphQLAuthModule.with_auth_root_queries(self._services, public=public)
|
||||
except ImportError:
|
||||
__not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations)
|
||||
|
||||
def with_auth_root_mutations(self, public: bool = False):
|
||||
try:
|
||||
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
|
||||
|
||||
GraphQLAuthModule.with_auth_root_mutations(self._services, public=public)
|
||||
except ImportError:
|
||||
__not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations)
|
||||
|
||||
async def _log_before_startup(self):
|
||||
host = self._api_settings.host
|
||||
if host == "0.0.0.0" and Environment.get_environment() == "development":
|
||||
host = "localhost"
|
||||
elif host == "0.0.0.0":
|
||||
host = socket.gethostbyname(socket.gethostname())
|
||||
|
||||
self._logger.info(f"Start API on {host}:{self._api_settings.port}")
|
||||
if self._with_graphiql:
|
||||
self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql")
|
||||
if self._with_playground:
|
||||
self._logger.warning(
|
||||
f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground"
|
||||
)
|
||||
0
src/graphql/cpl/graphql/auth/__init__.py
Normal file
0
src/graphql/cpl/graphql/auth/__init__.py
Normal file
0
src/graphql/cpl/graphql/auth/api_key/__init__.py
Normal file
0
src/graphql/cpl/graphql/auth/api_key/__init__.py
Normal file
10
src/graphql/cpl/graphql/auth/api_key/api_key_filter.py
Normal file
10
src/graphql/cpl/graphql/auth/api_key/api_key_filter.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.auth.schema import ApiKey
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
|
||||
|
||||
class ApiKeyFilter(DbModelFilter[ApiKey]):
|
||||
def __init__(self, public: bool = False):
|
||||
DbModelFilter.__init__(self, public)
|
||||
|
||||
self.field("identifier", StringFilter).with_public(public)
|
||||
14
src/graphql/cpl/graphql/auth/api_key/api_key_graph_type.py
Normal file
14
src/graphql/cpl/graphql/auth/api_key/api_key_graph_type.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from cpl.auth.schema import ApiKey, RolePermissionDao
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
|
||||
|
||||
class ApiKeyGraphType(DbModelGraphType[ApiKey]):
|
||||
|
||||
def __init__(self, role_permission_dao: RolePermissionDao):
|
||||
DbModelGraphType.__init__(self)
|
||||
|
||||
self.string_field(ApiKey.identifier, lambda root: root.identifier)
|
||||
self.string_field(ApiKey.key, lambda root: root.key)
|
||||
self.string_field(ApiKey.permissions, lambda root: root.permissions)
|
||||
|
||||
self.set_history_reference_dao(role_permission_dao, "apikeyid")
|
||||
25
src/graphql/cpl/graphql/auth/api_key/api_key_input.py
Normal file
25
src/graphql/cpl/graphql/auth/api_key/api_key_input.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from cpl.auth.schema import ApiKey
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class ApiKeyCreateInput(Input[ApiKey]):
|
||||
identifier: str
|
||||
permissions: list[SerialId]
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.string_field("identifier").with_required()
|
||||
self.list_field("permissions", SerialId)
|
||||
|
||||
|
||||
class ApiKeyUpdateInput(Input[ApiKey]):
|
||||
id: SerialId
|
||||
identifier: str | None
|
||||
permissions: list[SerialId] | None
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.int_field("id").with_required()
|
||||
self.string_field("identifier").with_required()
|
||||
self.list_field("permissions", SerialId)
|
||||
93
src/graphql/cpl/graphql/auth/api_key/api_key_mutation.py
Normal file
93
src/graphql/cpl/graphql/auth/api_key/api_key_mutation.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from cpl.api import APILogger
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.auth.schema import ApiKey, ApiKeyDao, ApiKeyPermissionDao, ApiKeyPermission
|
||||
from cpl.graphql.auth.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
|
||||
class ApiKeyMutation(Mutation):
|
||||
def __init__(
|
||||
self,
|
||||
logger: APILogger,
|
||||
api_key_dao: ApiKeyDao,
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
permission_dao: ApiKeyPermissionDao,
|
||||
keycloak_admin: KeycloakAdmin,
|
||||
):
|
||||
Mutation.__init__(self)
|
||||
self._logger = logger
|
||||
self._api_key_dao = api_key_dao
|
||||
self._api_key_permission_dao = api_key_permission_dao
|
||||
self._permission_dao = permission_dao
|
||||
self._keycloak_admin = keycloak_admin
|
||||
|
||||
self.int_field(
|
||||
"create",
|
||||
self.resolve_create,
|
||||
).with_require_any_permission(Permissions.api_keys_create).with_argument(
|
||||
"input",
|
||||
ApiKeyCreateInput,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"update",
|
||||
self.resolve_update,
|
||||
).with_require_any_permission(Permissions.api_keys_update).with_argument(
|
||||
"input",
|
||||
ApiKeyUpdateInput,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"delete",
|
||||
self.resolve_delete,
|
||||
).with_require_any_permission(Permissions.api_keys_delete).with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"restore",
|
||||
self.resolve_restore,
|
||||
).with_require_any_permission(Permissions.api_keys_delete).with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
async def resolve_create(self, obj: ApiKeyCreateInput):
|
||||
self._logger.debug(f"create api key: {obj.__dict__}")
|
||||
|
||||
api_key = ApiKey.new(obj.identifier)
|
||||
await self._api_key_dao.create(api_key)
|
||||
api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}])
|
||||
await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions])
|
||||
return api_key
|
||||
|
||||
async def resolve_update(self, input: ApiKeyUpdateInput):
|
||||
self._logger.debug(f"update api key: {input}")
|
||||
api_key = await self._api_key_dao.get_by_id(input.id)
|
||||
|
||||
await self._resolve_assignments(
|
||||
input.permissions or [],
|
||||
api_key,
|
||||
ApiKeyPermission.api_key_id,
|
||||
ApiKeyPermission.permission_id,
|
||||
self._api_key_dao,
|
||||
self._api_key_permission_dao,
|
||||
ApiKeyPermission,
|
||||
self._permission_dao,
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
async def resolve_delete(self, id: str):
|
||||
self._logger.debug(f"delete api key: {id}")
|
||||
api_key = await self._api_key_dao.get_by_id(id)
|
||||
await self._api_key_dao.delete(api_key)
|
||||
return True
|
||||
|
||||
async def resolve_restore(self, id: str):
|
||||
self._logger.debug(f"restore api key: {id}")
|
||||
api_key = await self._api_key_dao.get_by_id(id)
|
||||
await self._api_key_dao.restore(api_key)
|
||||
return True
|
||||
9
src/graphql/cpl/graphql/auth/api_key/api_key_sort.py
Normal file
9
src/graphql/cpl/graphql/auth/api_key/api_key_sort.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from cpl.auth.schema import ApiKey
|
||||
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class ApiKeySort(DbModelSort[ApiKey]):
|
||||
def __init__(self):
|
||||
DbModelSort.__init__(self)
|
||||
self.field("identifier", SortOrder)
|
||||
77
src/graphql/cpl/graphql/auth/graphql_auth_module.py
Normal file
77
src/graphql/cpl/graphql/auth/graphql_auth_module.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.auth.schema import UserDao, ApiKeyDao, RoleDao
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
from cpl.graphql.auth.api_key.api_key_filter import ApiKeyFilter
|
||||
from cpl.graphql.auth.api_key.api_key_graph_type import ApiKeyGraphType
|
||||
from cpl.graphql.auth.api_key.api_key_mutation import ApiKeyMutation
|
||||
from cpl.graphql.auth.api_key.api_key_sort import ApiKeySort
|
||||
from cpl.graphql.auth.role.role_filter import RoleFilter
|
||||
from cpl.graphql.auth.role.role_graph_type import RoleGraphType
|
||||
from cpl.graphql.auth.role.role_mutation import RoleMutation
|
||||
from cpl.graphql.auth.role.role_sort import RoleSort
|
||||
from cpl.graphql.auth.user.user_filter import UserFilter
|
||||
from cpl.graphql.auth.user.user_graph_type import UserGraphType
|
||||
from cpl.graphql.auth.user.user_mutation import UserMutation
|
||||
from cpl.graphql.auth.user.user_sort import UserSort
|
||||
from cpl.graphql.graphql_module import GraphQLModule
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLAuthModule(Module):
|
||||
dependencies = [GraphQLModule]
|
||||
transient = [
|
||||
UserGraphType,
|
||||
UserMutation,
|
||||
UserFilter,
|
||||
UserSort,
|
||||
ApiKeyGraphType,
|
||||
ApiKeyMutation,
|
||||
ApiKeyFilter,
|
||||
ApiKeySort,
|
||||
RoleGraphType,
|
||||
RoleMutation,
|
||||
RoleFilter,
|
||||
RoleSort,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
Configuration.set("GraphQLAuthModuleEnabled", True)
|
||||
|
||||
@staticmethod
|
||||
def configure(provider: ServiceProvider):
|
||||
schema = provider.get_service(Schema)
|
||||
schema.with_type(UserGraphType)
|
||||
schema.with_type(ApiKeyGraphType)
|
||||
schema.with_type(RoleGraphType)
|
||||
|
||||
@staticmethod
|
||||
def with_auth_root_queries(provider: ServiceProvider, public: bool = False):
|
||||
if not Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'")
|
||||
|
||||
schema = provider.get_service(Schema)
|
||||
schema.query.dao_collection_field(
|
||||
UserGraphType, UserDao, "users", UserFilter, UserSort
|
||||
).with_require_any_permission(Permissions.users).with_public(public)
|
||||
|
||||
schema.query.dao_collection_field(
|
||||
ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort
|
||||
).with_require_any_permission(Permissions.api_keys).with_public(public)
|
||||
|
||||
schema.query.dao_collection_field(
|
||||
RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort
|
||||
).with_require_any_permission(Permissions.roles).with_public(public)
|
||||
|
||||
@staticmethod
|
||||
def with_auth_root_mutations(provider: ServiceProvider, public: bool = False):
|
||||
if not Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'")
|
||||
|
||||
schema = provider.get_service(Schema)
|
||||
schema.mutation.with_mutation("user", UserMutation).with_public(public)
|
||||
schema.mutation.with_mutation("apiKey", ApiKeyMutation).with_public(public)
|
||||
schema.mutation.with_mutation("role", RoleMutation).with_public(public)
|
||||
0
src/graphql/cpl/graphql/auth/role/__init__.py
Normal file
0
src/graphql/cpl/graphql/auth/role/__init__.py
Normal file
11
src/graphql/cpl/graphql/auth/role/role_filter.py
Normal file
11
src/graphql/cpl/graphql/auth/role/role_filter.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from cpl.auth.schema import User, Role
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
|
||||
|
||||
class RoleFilter(DbModelFilter[Role]):
|
||||
def __init__(self, public: bool = False):
|
||||
DbModelFilter.__init__(self, public)
|
||||
|
||||
self.field("name", StringFilter).with_public(public)
|
||||
self.field("description", StringFilter).with_public(public)
|
||||
14
src/graphql/cpl/graphql/auth/role/role_graph_type.py
Normal file
14
src/graphql/cpl/graphql/auth/role/role_graph_type.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from cpl.auth.schema import Role
|
||||
from cpl.graphql.auth.user.user_graph_type import UserGraphType
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
|
||||
|
||||
class RoleGraphType(DbModelGraphType[Role]):
|
||||
|
||||
def __init__(self, public: bool = False):
|
||||
DbModelGraphType.__init__(self)
|
||||
|
||||
self.string_field("name", lambda root: root.name).with_public(public)
|
||||
self.string_field("description", lambda root: root.description).with_public(public)
|
||||
self.list_field("permissions", str, lambda root: root.permissions).with_public(public)
|
||||
self.list_field("users", UserGraphType, lambda root: root.users).with_public(public)
|
||||
29
src/graphql/cpl/graphql/auth/role/role_input.py
Normal file
29
src/graphql/cpl/graphql/auth/role/role_input.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from cpl.auth.schema import User, Role
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class RoleCreateInput(Input[Role]):
|
||||
name: str
|
||||
description: str | None
|
||||
permissions: list[SerialId] | None
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.string_field("name").with_required()
|
||||
self.string_field("description")
|
||||
self.list_field("permissions", SerialId)
|
||||
|
||||
|
||||
class RoleUpdateInput(Input[Role]):
|
||||
id: SerialId
|
||||
name: str | None
|
||||
description: str | None
|
||||
permissions: list[SerialId] | None
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.int_field("id").with_required()
|
||||
self.string_field("name")
|
||||
self.string_field("description")
|
||||
self.list_field("permissions", SerialId)
|
||||
101
src/graphql/cpl/graphql/auth/role/role_mutation.py
Normal file
101
src/graphql/cpl/graphql/auth/role/role_mutation.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from cpl.api import APILogger
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.auth.schema import RoleDao, Role, RolePermissionDao, RolePermission
|
||||
from cpl.graphql.auth.role.role_input import RoleCreateInput, RoleUpdateInput
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
|
||||
class RoleMutation(Mutation):
|
||||
def __init__(
|
||||
self,
|
||||
logger: APILogger,
|
||||
role_dao: RoleDao,
|
||||
role_permission_dao: RolePermissionDao,
|
||||
permission_dao: RolePermissionDao,
|
||||
keycloak_admin: KeycloakAdmin,
|
||||
):
|
||||
Mutation.__init__(self)
|
||||
self._logger = logger
|
||||
self._role_dao = role_dao
|
||||
self._role_permission_dao = role_permission_dao
|
||||
self._permission_dao = permission_dao
|
||||
self._keycloak_admin = keycloak_admin
|
||||
|
||||
self.int_field(
|
||||
"create",
|
||||
self.resolve_create,
|
||||
).with_require_any_permission(Permissions.roles_create).with_argument(
|
||||
"input",
|
||||
RoleCreateInput,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"update",
|
||||
self.resolve_update,
|
||||
).with_require_any_permission(Permissions.roles_update).with_argument(
|
||||
"input",
|
||||
RoleUpdateInput,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"delete",
|
||||
self.resolve_delete,
|
||||
).with_require_any_permission(Permissions.roles_delete).with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"restore",
|
||||
self.resolve_restore,
|
||||
).with_require_any_permission(Permissions.roles_delete).with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
async def resolve_create(self, input: RoleCreateInput, *_):
|
||||
self._logger.debug(f"create role: {input.__dict__}")
|
||||
|
||||
role = Role(
|
||||
0,
|
||||
input.name,
|
||||
input.description,
|
||||
)
|
||||
await self._role_dao.create(role)
|
||||
role = await self._role_dao.get_by_name(role.name)
|
||||
await self._role_permission_dao.create_many([RolePermission(0, role.id, x) for x in input.permissions])
|
||||
|
||||
return role
|
||||
|
||||
async def resolve_update(self, input: RoleUpdateInput, *_):
|
||||
self._logger.debug(f"update role: {input.__dict__}")
|
||||
role = await self._role_dao.get_by_id(input.id)
|
||||
role.name = input.get("name", role.name)
|
||||
role.description = input.get("description", role.description)
|
||||
await self._role_dao.update(role)
|
||||
|
||||
await self._resolve_assignments(
|
||||
input.get("permissions", []),
|
||||
role,
|
||||
RolePermission.role_id,
|
||||
RolePermission.permission_id,
|
||||
self._role_dao,
|
||||
self._role_permission_dao,
|
||||
RolePermission,
|
||||
self._permission_dao,
|
||||
)
|
||||
|
||||
return role
|
||||
|
||||
async def resolve_delete(self, id: int):
|
||||
self._logger.debug(f"delete role: {id}")
|
||||
role = await self._role_dao.get_by_id(id)
|
||||
await self._role_dao.delete(role)
|
||||
return True
|
||||
|
||||
async def resolve_restore(self, id: int):
|
||||
self._logger.debug(f"restore role: {id}")
|
||||
role = await self._role_dao.get_by_id(id)
|
||||
await self._role_dao.restore(role)
|
||||
return True
|
||||
10
src/graphql/cpl/graphql/auth/role/role_sort.py
Normal file
10
src/graphql/cpl/graphql/auth/role/role_sort.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.auth.schema import Role
|
||||
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class RoleSort(DbModelSort[Role]):
|
||||
def __init__(self):
|
||||
DbModelSort.__init__(self)
|
||||
self.field("name", SortOrder)
|
||||
self.field("description", SortOrder)
|
||||
0
src/graphql/cpl/graphql/auth/user/__init__.py
Normal file
0
src/graphql/cpl/graphql/auth/user/__init__.py
Normal file
11
src/graphql/cpl/graphql/auth/user/user_filter.py
Normal file
11
src/graphql/cpl/graphql/auth/user/user_filter.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from cpl.auth.schema import User
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
|
||||
|
||||
class UserFilter(DbModelFilter[User]):
|
||||
def __init__(self, public: bool = False):
|
||||
DbModelFilter.__init__(self, public)
|
||||
|
||||
self.field("username", StringFilter).with_public(public)
|
||||
self.field("email", StringFilter).with_public(public)
|
||||
12
src/graphql/cpl/graphql/auth/user/user_graph_type.py
Normal file
12
src/graphql/cpl/graphql/auth/user/user_graph_type.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from cpl.auth.schema import User
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
|
||||
|
||||
class UserGraphType(DbModelGraphType[User]):
|
||||
|
||||
def __init__(self, public: bool = False):
|
||||
DbModelGraphType.__init__(self)
|
||||
|
||||
self.string_field(User.keycloak_id, lambda root: root.keycloak_id).with_public(public)
|
||||
self.string_field(User.username, lambda root: root.username).with_public(public)
|
||||
self.string_field(User.email, lambda root: root.email).with_public(public)
|
||||
23
src/graphql/cpl/graphql/auth/user/user_input.py
Normal file
23
src/graphql/cpl/graphql/auth/user/user_input.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from cpl.auth.schema import User
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class UserCreateInput(Input[User]):
|
||||
keycloak_id: str
|
||||
roles: list[SerialId] | None
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.string_field("keycloak_id").with_required()
|
||||
self.list_field("roles", SerialId)
|
||||
|
||||
|
||||
class UserUpdateInput(Input[User]):
|
||||
id: SerialId
|
||||
roles: list[SerialId] | None
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.int_field("id").with_required()
|
||||
self.list_field("roles", SerialId)
|
||||
112
src/graphql/cpl/graphql/auth/user/user_mutation.py
Normal file
112
src/graphql/cpl/graphql/auth/user/user_mutation.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from cpl.api import APILogger
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.auth.schema import UserDao, User, RoleUser, RoleUserDao, RoleDao
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.graphql.auth.user.user_input import UserCreateInput, UserUpdateInput
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
|
||||
class UserMutation(Mutation):
|
||||
def __init__(
|
||||
self,
|
||||
logger: APILogger,
|
||||
user_dao: UserDao,
|
||||
role_user_dao: RoleUserDao,
|
||||
role_dao: RoleDao,
|
||||
keycloak_admin: KeycloakAdmin,
|
||||
):
|
||||
Mutation.__init__(self)
|
||||
self._logger = logger
|
||||
self._user_dao = user_dao
|
||||
self._role_user_dao = role_user_dao
|
||||
self._role_dao = role_dao
|
||||
self._keycloak_admin = keycloak_admin
|
||||
|
||||
self.int_field(
|
||||
"create",
|
||||
self.resolve_create,
|
||||
).with_require_any_permission(Permissions.users_create).with_argument(
|
||||
"input",
|
||||
UserCreateInput,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"update",
|
||||
self.resolve_update,
|
||||
).with_require_any_permission(Permissions.users_update).with_argument(
|
||||
"input",
|
||||
UserUpdateInput,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"delete",
|
||||
self.resolve_delete,
|
||||
).with_require_any_permission(Permissions.users_delete).with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
self.bool_field(
|
||||
"restore",
|
||||
self.resolve_restore,
|
||||
).with_require_any_permission(Permissions.users_delete).with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
async def resolve_create(self, input: UserCreateInput):
|
||||
self._logger.debug(f"create user: {input.__dict__}")
|
||||
|
||||
# ensure keycloak knows a user with this keycloak_id
|
||||
# get_user should raise an exception if the user does not exist
|
||||
kc_user = self._keycloak_admin.get_user(input.keycloak_id)
|
||||
if kc_user is None:
|
||||
raise ValueError(f"Keycloak user with id {input.keycloak_id} does not exist")
|
||||
|
||||
user = User(0, input.keycloak_id, input.license)
|
||||
user_id = await self._user_dao.create(user)
|
||||
user = await self._user_dao.get_by_id(user_id)
|
||||
await self._role_user_dao.create_many([RoleUser(0, user.id, x) for x in set(input.roles)])
|
||||
|
||||
return user
|
||||
|
||||
async def resolve_update(self, input: UserUpdateInput):
|
||||
self._logger.debug(f"update user: {input.__dict__}")
|
||||
user = await self._user_dao.get_by_id(input.id)
|
||||
|
||||
if input.license:
|
||||
user.license = input.license
|
||||
|
||||
await self._user_dao.update(user)
|
||||
await self._resolve_assignments(
|
||||
input.roles or [],
|
||||
user,
|
||||
RoleUser.user_id,
|
||||
RoleUser.role_id,
|
||||
self._user_dao,
|
||||
self._role_user_dao,
|
||||
RoleUser,
|
||||
self._role_dao,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
async def resolve_delete(self, id: int):
|
||||
self._logger.debug(f"delete user: {id}")
|
||||
user = await self._user_dao.get_by_id(id)
|
||||
await self._user_dao.delete(user)
|
||||
try:
|
||||
active_user = get_user()
|
||||
if active_user is not None and active_user.id == user.id:
|
||||
# await broadcast.publish("userLogout", user.id)
|
||||
self._keycloak_admin.user_logout(user_id=user.keycloak_id)
|
||||
except Exception as e:
|
||||
self._logger.error(f"Failed to logout user from Keycloak", e)
|
||||
return True
|
||||
|
||||
async def resolve_restore(self, id: int):
|
||||
self._logger.debug(f"restore user: {id}")
|
||||
user = await self._user_dao.get_by_id(id)
|
||||
await self._user_dao.restore(user)
|
||||
return True
|
||||
10
src/graphql/cpl/graphql/auth/user/user_sort.py
Normal file
10
src/graphql/cpl/graphql/auth/user/user_sort.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.auth.schema import User
|
||||
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class UserSort(DbModelSort[User]):
|
||||
def __init__(self):
|
||||
DbModelSort.__init__(self)
|
||||
self.field("username", SortOrder)
|
||||
self.field("email", SortOrder)
|
||||
14
src/graphql/cpl/graphql/error.py
Normal file
14
src/graphql/cpl/graphql/error.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from graphql import GraphQLError
|
||||
|
||||
from cpl.api import APIError
|
||||
|
||||
|
||||
def graphql_error(api_error: APIError) -> GraphQLError:
|
||||
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
|
||||
return GraphQLError(
|
||||
message=api_error.error_message,
|
||||
extensions={
|
||||
"code": api_error.status_code,
|
||||
},
|
||||
original_error=api_error,
|
||||
)
|
||||
0
src/graphql/cpl/graphql/event_bus/__init__.py
Normal file
0
src/graphql/cpl/graphql/event_bus/__init__.py
Normal file
27
src/graphql/cpl/graphql/event_bus/memory.py
Normal file
27
src/graphql/cpl/graphql/event_bus/memory.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
|
||||
|
||||
class InMemoryEventBus(EventBusABC):
|
||||
def __init__(self):
|
||||
self._subscribers: dict[str, list[asyncio.Queue]] = {}
|
||||
|
||||
async def publish(self, channel: str, event: Any) -> None:
|
||||
queues = self._subscribers.get(channel, [])
|
||||
for q in queues.copy():
|
||||
await q.put(event)
|
||||
|
||||
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]:
|
||||
q = asyncio.Queue()
|
||||
if channel not in self._subscribers:
|
||||
self._subscribers[channel] = []
|
||||
self._subscribers[channel].append(q)
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await q.get()
|
||||
yield item
|
||||
finally:
|
||||
self._subscribers[channel].remove(q)
|
||||
25
src/graphql/cpl/graphql/graphql_module.py
Normal file
25
src/graphql/cpl/graphql/graphql_module.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.schema.root_subscription import RootSubscription
|
||||
from cpl.graphql.service.graphql import GraphQLService
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLModule(Module):
|
||||
dependencies = [ApiModule]
|
||||
singleton = [Schema, RootQuery, RootMutation, RootSubscription]
|
||||
scoped = [GraphQLService]
|
||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||
|
||||
@staticmethod
|
||||
def configure(services: ServiceProvider) -> None:
|
||||
schema = services.get_service(Schema)
|
||||
schema.build()
|
||||
48
src/graphql/cpl/graphql/query_context.py
Normal file
48
src/graphql/cpl/graphql/query_context.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from graphql import GraphQLResolveInfo
|
||||
|
||||
from cpl.auth.schema import User, Permission
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
|
||||
class QueryContext:
|
||||
|
||||
def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs):
|
||||
self._user = get_user()
|
||||
self._user_permissions = user_permissions or []
|
||||
|
||||
self._resolve_info = None
|
||||
for arg in args:
|
||||
if isinstance(arg, GraphQLResolveInfo):
|
||||
self._resolve_info = arg
|
||||
continue
|
||||
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
self._is_mutation = is_mutation
|
||||
|
||||
@property
|
||||
def user(self) -> User:
|
||||
return self._user
|
||||
|
||||
@property
|
||||
def resolve_info(self) -> Optional[GraphQLResolveInfo]:
|
||||
return self._resolve_info
|
||||
|
||||
@property
|
||||
def args(self) -> tuple:
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def kwargs(self) -> dict:
|
||||
return self._kwargs
|
||||
|
||||
@property
|
||||
def is_mutation(self) -> bool:
|
||||
return self._is_mutation
|
||||
|
||||
def has_permission(self, permission: Enum | str) -> bool:
|
||||
return permission.value if isinstance(permission, Enum) else permission in self._user_permissions
|
||||
0
src/graphql/cpl/graphql/schema/__init__.py
Normal file
0
src/graphql/cpl/graphql/schema/__init__.py
Normal file
54
src/graphql/cpl/graphql/schema/argument.py
Normal file
54
src/graphql/cpl/graphql/schema/argument.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Self
|
||||
|
||||
|
||||
class Argument:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
t: type,
|
||||
description: str = None,
|
||||
default: Any = None,
|
||||
optional: bool = None,
|
||||
):
|
||||
self._name = name
|
||||
self._type = t
|
||||
self._description = description
|
||||
self._default = default
|
||||
self._optional = optional
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def default(self) -> Any | None:
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def optional(self) -> bool | None:
|
||||
return self._optional
|
||||
|
||||
def with_description(self, description: str) -> Self:
|
||||
self._description = description
|
||||
return self
|
||||
|
||||
def with_default(self, default: Any) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_optional(self, optional: bool) -> Self:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_required(self, required: bool = True) -> Self:
|
||||
self._optional = not required
|
||||
return self
|
||||
61
src/graphql/cpl/graphql/schema/collection.py
Normal file
61
src/graphql/cpl/graphql/schema/collection.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Type, Dict, List
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
|
||||
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
class CollectionGraphTypeFactory:
|
||||
@classmethod
|
||||
def get(cls, node_type: Type[StrawberryProtocol]) -> Type:
|
||||
type_name = f"{node_type.__name__.replace('GraphType', '')}Collection"
|
||||
|
||||
if TypeCollector.has(type_name):
|
||||
return TypeCollector.get(type_name)
|
||||
|
||||
node_t = get_provider().get_service(node_type)
|
||||
if not node_t:
|
||||
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
|
||||
|
||||
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||
|
||||
gql_cls = type(type_name, (), {})
|
||||
|
||||
TypeCollector.set(type_name, gql_cls)
|
||||
|
||||
gql_cls.__annotations__ = {
|
||||
"nodes": List[gql_node],
|
||||
"total_count": int,
|
||||
"count": int,
|
||||
}
|
||||
for k in gql_cls.__annotations__.keys():
|
||||
setattr(gql_cls, k, strawberry.field())
|
||||
|
||||
gql_type = strawberry.type(gql_cls)
|
||||
|
||||
TypeCollector.set(type_name, gql_type)
|
||||
return gql_type
|
||||
|
||||
|
||||
class Collection:
|
||||
def __init__(self, nodes: list[T], total_count: int, count: int):
|
||||
self._nodes = nodes
|
||||
self._total_count = total_count
|
||||
self._count = count
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[T]:
|
||||
return self._nodes
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
return self._total_count
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return self._count
|
||||
62
src/graphql/cpl/graphql/schema/db_model_graph_type.py
Normal file
62
src/graphql/cpl/graphql/schema/db_model_graph_type.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Type, Optional, Generic, Annotated
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.typing import T
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.query import Query
|
||||
|
||||
|
||||
class DbModelGraphType(GraphType[T], Generic[T]):
|
||||
|
||||
def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False):
|
||||
Query.__init__(self)
|
||||
|
||||
self._dao: Optional[DataAccessObjectABC] = None
|
||||
|
||||
if t_dao is not None:
|
||||
dao = self._provider.get_service(t_dao)
|
||||
if dao is not None:
|
||||
self._dao = dao
|
||||
|
||||
self.int_field("id", lambda root: root.id).with_public(public)
|
||||
self.bool_field("deleted", lambda root: root.deleted).with_public(public)
|
||||
|
||||
if Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
from cpl.graphql.auth.user.user_graph_type import UserGraphType
|
||||
|
||||
self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public)
|
||||
|
||||
self.string_field("created", lambda root: root.created).with_public(public)
|
||||
self.string_field("updated", lambda root: root.updated).with_public(public)
|
||||
|
||||
# if with_history:
|
||||
# if self._dao is None:
|
||||
# raise ValueError("DAO must be provided to enable history")
|
||||
# self.set_field("history", self._resolve_history).with_public(public)
|
||||
|
||||
self._history_reference_daos: dict[DataAccessObjectABC, str] = {}
|
||||
|
||||
async def _resolve_history(self, root):
|
||||
if self._dao is None:
|
||||
raise Exception("DAO not set for history query")
|
||||
|
||||
history = sorted(
|
||||
[await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)],
|
||||
key=lambda h: h.updated,
|
||||
reverse=True,
|
||||
)
|
||||
return history
|
||||
|
||||
def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None):
|
||||
"""
|
||||
Set the reference DAO for history resolution.
|
||||
:param dao:
|
||||
:param key: The key to use for resolving history.
|
||||
:return:
|
||||
"""
|
||||
if key is None:
|
||||
key = "id"
|
||||
self._history_reference_daos[dao] = key
|
||||
141
src/graphql/cpl/graphql/schema/field.py
Normal file
141
src/graphql/cpl/graphql/schema/field.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from cpl.graphql.schema.argument import Argument
|
||||
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
|
||||
|
||||
|
||||
class Field:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
t: type = None,
|
||||
resolver: Resolver = None,
|
||||
optional=None,
|
||||
default=None,
|
||||
subquery: TQuery = None,
|
||||
parent_type=None,
|
||||
):
|
||||
self._name = name
|
||||
self._type = t
|
||||
self._resolver = resolver
|
||||
self._optional = optional or True
|
||||
self._default = default
|
||||
|
||||
self._subquery = subquery
|
||||
self._parent_type = parent_type
|
||||
|
||||
self._args: dict[str, Argument] = {}
|
||||
self._require_any_permission = None
|
||||
self._require_any = None
|
||||
self._public = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self) -> type:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def resolver(self) -> callable:
|
||||
return self._resolver
|
||||
|
||||
@property
|
||||
def optional(self) -> bool | None:
|
||||
return self._optional
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def subquery(self) -> TQuery | None:
|
||||
return self._subquery
|
||||
|
||||
@property
|
||||
def parent_type(self):
|
||||
return self._parent_type
|
||||
|
||||
@property
|
||||
def arguments(self) -> dict[str, Argument]:
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def require_any_permission(self) -> TRequireAnyPermissions | None:
|
||||
return self._require_any_permission
|
||||
|
||||
@property
|
||||
def require_any(self) -> TRequireAnyResolvers | None:
|
||||
return self._require_any
|
||||
|
||||
@property
|
||||
def public(self) -> bool:
|
||||
return self._public
|
||||
|
||||
def with_type(self, t: type) -> Self:
|
||||
self._type = t
|
||||
return self
|
||||
|
||||
def with_resolver(self, resolver: Resolver) -> Self:
|
||||
self._resolver = resolver
|
||||
return self
|
||||
|
||||
def with_optional(self, optional: bool = True) -> Self:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
def with_required(self, required: bool = True) -> Self:
|
||||
self._optional = not required
|
||||
return self
|
||||
|
||||
def with_default(self, default) -> Self:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_argument(
|
||||
self, name: str, arg_type: type, description: str = None, default_value=None, optional=True
|
||||
) -> Argument:
|
||||
if name in self._args:
|
||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||
self._args[name] = Argument(name, arg_type, description, default_value, optional)
|
||||
return self._args[name]
|
||||
|
||||
def with_arguments(self, args: list[Argument]) -> Self:
|
||||
for arg in args:
|
||||
if not isinstance(arg, Argument):
|
||||
raise ValueError(f"Expected Argument instance, got {type(arg)}")
|
||||
|
||||
self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional)
|
||||
return self
|
||||
|
||||
def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self:
|
||||
if not isinstance(permissions, list):
|
||||
permissions = list(permissions)
|
||||
|
||||
assert permissions is not None, "require_any_permission cannot be None"
|
||||
assert all(isinstance(x, (str, Enum)) for x in permissions), "All permissions must be of Permission type"
|
||||
self._require_any_permission = permissions
|
||||
return self
|
||||
|
||||
def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self:
|
||||
assert permissions is not None, "permissions cannot be None"
|
||||
assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type"
|
||||
assert resolvers is not None, "resolvers cannot be None"
|
||||
assert all(callable(r) for r in resolvers), "All resolvers must be callable"
|
||||
self._require_any = (permissions, resolvers)
|
||||
return self
|
||||
|
||||
def with_public(self, public: bool = True) -> Self:
|
||||
if public:
|
||||
self._require_any = None
|
||||
self._require_any_permission = None
|
||||
|
||||
self._public = public
|
||||
return self
|
||||
0
src/graphql/cpl/graphql/schema/filter/__init__.py
Normal file
0
src/graphql/cpl/graphql/schema/filter/__init__.py
Normal file
10
src/graphql/cpl/graphql/schema/filter/bool_filter.py
Normal file
10
src/graphql/cpl/graphql/schema/filter/bool_filter.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class BoolFilter(Input[bool]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", bool, optional=True)
|
||||
self.field("notEqual", bool, optional=True)
|
||||
self.field("isNull", bool, optional=True)
|
||||
self.field("isNotNull", bool, optional=True)
|
||||
18
src/graphql/cpl/graphql/schema/filter/date_filter.py
Normal file
18
src/graphql/cpl/graphql/schema/filter/date_filter.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class DateFilter(Input[datetime]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", datetime, optional=True)
|
||||
self.field("notEqual", datetime, optional=True)
|
||||
self.field("greater", datetime, optional=True)
|
||||
self.field("greaterOrEqual", datetime, optional=True)
|
||||
self.field("less", datetime, optional=True)
|
||||
self.field("lessOrEqual", datetime, optional=True)
|
||||
self.field("isNull", datetime, optional=True)
|
||||
self.field("isNotNull", datetime, optional=True)
|
||||
self.field("in", list[datetime], optional=True)
|
||||
self.field("notIn", list[datetime], optional=True)
|
||||
23
src/graphql/cpl/graphql/schema/filter/db_model_filter.py
Normal file
23
src/graphql/cpl/graphql/schema/filter/db_model_filter.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
|
||||
|
||||
class DbModelFilter(Filter[T], Generic[T]):
|
||||
def __init__(self, public: bool = False):
|
||||
Filter.__init__(self)
|
||||
|
||||
self.field("id", IntFilter).with_public(public)
|
||||
self.field("deleted", BoolFilter).with_public(public)
|
||||
if Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
from cpl.graphql.auth.user.user_filter import UserFilter
|
||||
|
||||
self.field("editor", lambda: UserFilter).with_public(public)
|
||||
|
||||
self.field("created", DateFilter).with_public(public)
|
||||
self.field("updated", DateFilter).with_public(public)
|
||||
28
src/graphql/cpl/graphql/schema/filter/filter.py
Normal file
28
src/graphql/cpl/graphql/schema/filter/filter.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class Filter(Input[T]):
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
|
||||
def filter_field(self, name: str, filter_type: Type["Filter"]):
|
||||
self.field(name, filter_type)
|
||||
|
||||
def string_field(self, name: str):
|
||||
self.field(name, StringFilter)
|
||||
|
||||
def int_field(self, name: str):
|
||||
self.field(name, IntFilter)
|
||||
|
||||
def bool_field(self, name: str):
|
||||
self.field(name, BoolFilter)
|
||||
|
||||
def date_field(self, name: str):
|
||||
self.field(name, DateFilter)
|
||||
16
src/graphql/cpl/graphql/schema/filter/int_filter.py
Normal file
16
src/graphql/cpl/graphql/schema/filter/int_filter.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class IntFilter(Input[int]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", int, optional=True)
|
||||
self.field("notEqual", int, optional=True)
|
||||
self.field("greater", int, optional=True)
|
||||
self.field("greaterOrEqual", int, optional=True)
|
||||
self.field("less", int, optional=True)
|
||||
self.field("lessOrEqual", int, optional=True)
|
||||
self.field("isNull", int, optional=True)
|
||||
self.field("isNotNull", int, optional=True)
|
||||
self.field("in", list[int], optional=True)
|
||||
self.field("notIn", list[int], optional=True)
|
||||
16
src/graphql/cpl/graphql/schema/filter/string_filter.py
Normal file
16
src/graphql/cpl/graphql/schema/filter/string_filter.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class StringFilter(Input[str]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field("equal", str, optional=True)
|
||||
self.field("notEqual", str, optional=True)
|
||||
self.field("contains", str, optional=True)
|
||||
self.field("notContains", str, optional=True)
|
||||
self.field("startsWith", str, optional=True)
|
||||
self.field("endsWith", str, optional=True)
|
||||
self.field("isNull", str, optional=True)
|
||||
self.field("isNotNull", str, optional=True)
|
||||
self.field("in", list[str], optional=True)
|
||||
self.field("notIn", list[str], optional=True)
|
||||
10
src/graphql/cpl/graphql/schema/graph_type.py
Normal file
10
src/graphql/cpl/graphql/schema/graph_type.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.query import Query
|
||||
|
||||
|
||||
class GraphType(Query, Generic[T]):
|
||||
|
||||
def __init__(self):
|
||||
Query.__init__(self)
|
||||
115
src/graphql/cpl/graphql/schema/input.py
Normal file
115
src/graphql/cpl/graphql/schema/input.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import types
|
||||
from typing import Generic, Dict, Type, Optional, Union, Any
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import AttributeName
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||
|
||||
|
||||
class Input(StrawberryProtocol, Generic[T]):
|
||||
def __init__(self):
|
||||
self._fields: Dict[str, Field] = {}
|
||||
self._values: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def fields(self) -> Dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self._values:
|
||||
return self._values[item]
|
||||
raise AttributeError(f"{self.__class__.__name__} has no attribute {item}")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in {"_fields", "_values"}:
|
||||
super().__setattr__(key, value)
|
||||
elif key in self._fields:
|
||||
self._values[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field:
|
||||
if isinstance(name, property):
|
||||
name = name.fget.__name__
|
||||
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, str)
|
||||
|
||||
def int_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, int, optional)
|
||||
|
||||
def float_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, float, optional)
|
||||
|
||||
def bool_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, bool, optional)
|
||||
|
||||
def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field:
|
||||
return self.field(name, list[t], optional)
|
||||
|
||||
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||
if not isinstance(t, type) and callable(t):
|
||||
return self.field(name, t, optional)
|
||||
|
||||
return self.field(name, t().to_strawberry(), optional)
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
cls = self.__class__
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
|
||||
# register early to handle recursive types
|
||||
TypeCollector.set(cls, gql_cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
t = f.type
|
||||
|
||||
if isinstance(t, types.FunctionType):
|
||||
_t = get_provider().get_service(t())
|
||||
if _t is None:
|
||||
raise ValueError(f"'{t()}' could not be resolved from the provider")
|
||||
t = _t.to_strawberry()
|
||||
elif isinstance(t, type) and issubclass(t, Input):
|
||||
t = t().to_strawberry()
|
||||
elif isinstance(t, Input):
|
||||
t = t.to_strawberry()
|
||||
|
||||
py_name = name + "_" if name in _PYTHON_KEYWORDS else name
|
||||
annotations[py_name] = t if not f.optional else Optional[t]
|
||||
|
||||
field_args = {}
|
||||
if py_name != name:
|
||||
field_args["name"] = name
|
||||
|
||||
default = None if f.optional else f.default
|
||||
namespace[py_name] = strawberry.field(default=default, **field_args)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
|
||||
for k, v in namespace.items():
|
||||
setattr(gql_cls, k, v)
|
||||
|
||||
gql_cls.__annotations__ = annotations
|
||||
gql_type = strawberry.input(gql_cls)
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
93
src/graphql/cpl/graphql/schema/mutation.py
Normal file
93
src/graphql/cpl/graphql/schema/mutation.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Type, Union
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.database.abc import DataAccessObjectABC, DbJoinModelABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
|
||||
class Mutation(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
QueryABC.__init__(self)
|
||||
self._provider = provider
|
||||
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
|
||||
def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field:
|
||||
sub = self._provider.get_service(cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider")
|
||||
|
||||
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_assignments(
|
||||
foreign_objs: list[int],
|
||||
resolved_obj: T,
|
||||
reference_key_own: Union[str, property],
|
||||
reference_key_foreign: Union[str, property],
|
||||
source_dao: DataAccessObjectABC[T],
|
||||
join_dao: DataAccessObjectABC[T],
|
||||
join_type: Type[DbJoinModelABC],
|
||||
foreign_dao: DataAccessObjectABC[T],
|
||||
):
|
||||
if foreign_objs is None:
|
||||
return
|
||||
|
||||
reference_key_foreign_attr = reference_key_foreign
|
||||
if isinstance(reference_key_foreign, property):
|
||||
reference_key_foreign_attr = reference_key_foreign.fget.__name__
|
||||
|
||||
foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}])
|
||||
|
||||
to_delete = (
|
||||
foreign_list
|
||||
if len(foreign_objs) == 0
|
||||
else await join_dao.find_by(
|
||||
[
|
||||
{reference_key_own: resolved_obj.id},
|
||||
{reference_key_foreign: {"notIn": foreign_objs}},
|
||||
]
|
||||
)
|
||||
)
|
||||
foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list]
|
||||
deleted_foreign_ids = [
|
||||
getattr(x, reference_key_foreign_attr)
|
||||
for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}])
|
||||
]
|
||||
|
||||
to_create = [
|
||||
join_type(0, resolved_obj.id, x)
|
||||
for x in foreign_objs
|
||||
if x not in foreign_ids and x not in deleted_foreign_ids
|
||||
]
|
||||
to_restore = [
|
||||
await join_dao.get_single_by(
|
||||
[
|
||||
{reference_key_own: resolved_obj.id},
|
||||
{reference_key_foreign: x},
|
||||
]
|
||||
)
|
||||
for x in foreign_objs
|
||||
if x not in foreign_ids and x in deleted_foreign_ids
|
||||
]
|
||||
|
||||
if len(to_delete) > 0:
|
||||
await join_dao.delete_many(to_delete)
|
||||
|
||||
if len(to_create) > 0:
|
||||
await join_dao.create_many(to_create)
|
||||
|
||||
if len(to_restore) > 0:
|
||||
await join_dao.restore_many(to_restore)
|
||||
|
||||
foreign_changes = [*to_delete, *to_create, *to_restore]
|
||||
if len(foreign_changes) > 0:
|
||||
await source_dao.touch(resolved_obj)
|
||||
await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes])
|
||||
131
src/graphql/cpl/graphql/schema/query.py
Normal file
131
src/graphql/cpl/graphql/schema/query.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Callable, Type
|
||||
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class Query(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, provider: ServiceProvider):
|
||||
QueryABC.__init__(self)
|
||||
self._provider = provider
|
||||
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
self._schema = provider.get_service(Schema)
|
||||
|
||||
def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field:
|
||||
sub = self._provider.get_service(subquery_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider")
|
||||
|
||||
return self.field(name, sub.to_strawberry(), lambda: sub)
|
||||
|
||||
def collection_field(
|
||||
self,
|
||||
t: type,
|
||||
name: str,
|
||||
filter_type: Type[StrawberryProtocol],
|
||||
sort_type: Type[StrawberryProtocol],
|
||||
resolver: Callable,
|
||||
) -> Field:
|
||||
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
|
||||
items = resolver()
|
||||
if filter:
|
||||
for field, value in filter.__dict__.items():
|
||||
if value is None:
|
||||
continue
|
||||
items = [i for i in items if getattr(i, field) == value]
|
||||
|
||||
if sort:
|
||||
for field, direction in sort.__dict__.items():
|
||||
reverse = direction == SortOrder.DESC
|
||||
items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse)
|
||||
total_count = len(items)
|
||||
paged = items[skip : skip + take]
|
||||
return Collection(nodes=paged, total_count=total_count, count=len(paged))
|
||||
|
||||
filter = self._provider.get_service(filter_type)
|
||||
if not filter:
|
||||
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
|
||||
|
||||
sort = self._provider.get_service(sort_type)
|
||||
if not sort:
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection)
|
||||
f.with_argument("filter", filter.to_strawberry())
|
||||
f.with_argument("sort", sort.to_strawberry())
|
||||
f.with_argument("skip", int, default_value=0)
|
||||
f.with_argument("take", int, default_value=10)
|
||||
return f
|
||||
|
||||
def dao_collection_field(
|
||||
self,
|
||||
t: Type[StrawberryProtocol],
|
||||
dao_type: Type[DataAccessObjectABC],
|
||||
name: str,
|
||||
filter_type: Type[StrawberryProtocol],
|
||||
sort_type: Type[StrawberryProtocol],
|
||||
) -> Field:
|
||||
assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC"
|
||||
dao = self._provider.get_service(dao_type)
|
||||
if not dao:
|
||||
raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider")
|
||||
|
||||
filter = self._provider.get_service(filter_type)
|
||||
if not filter:
|
||||
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
|
||||
|
||||
sort = self._provider.get_service(sort_type)
|
||||
if not sort:
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
def input_to_dict(obj) -> dict | None:
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
for k, v in obj.__dict__.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if hasattr(v, "__dict__"):
|
||||
result[k] = input_to_dict(v)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
async def _resolver(filter=None, sort=None, take=10, skip=0):
|
||||
filter_dict = input_to_dict(filter) if filter is not None else None
|
||||
sort_dict = None
|
||||
|
||||
if sort is not None:
|
||||
sort_dict = {}
|
||||
for k, v in sort.__dict__.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if isinstance(v, SortOrder):
|
||||
sort_dict[k] = str(v.value).lower()
|
||||
continue
|
||||
|
||||
sort_dict[k] = str(v).lower()
|
||||
|
||||
total_count = await dao.count(filter_dict)
|
||||
data = await dao.find_by(filter_dict, sort_dict, take, skip)
|
||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||
f.with_argument("filter", filter.to_strawberry())
|
||||
f.with_argument("sort", sort.to_strawberry())
|
||||
f.with_argument("skip", int, default_value=0)
|
||||
f.with_argument("take", int, default_value=10)
|
||||
return f
|
||||
6
src/graphql/cpl/graphql/schema/root_mutation.py
Normal file
6
src/graphql/cpl/graphql/schema/root_mutation.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
|
||||
|
||||
class RootMutation(Mutation):
|
||||
def __init__(self):
|
||||
Mutation.__init__(self)
|
||||
6
src/graphql/cpl/graphql/schema/root_query.py
Normal file
6
src/graphql/cpl/graphql/schema/root_query.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.graphql.schema.query import Query
|
||||
|
||||
|
||||
class RootQuery(Query):
|
||||
def __init__(self):
|
||||
Query.__init__(self)
|
||||
6
src/graphql/cpl/graphql/schema/root_subscription.py
Normal file
6
src/graphql/cpl/graphql/schema/root_subscription.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.graphql.schema.subscription import Subscription
|
||||
|
||||
|
||||
class RootSubscription(Subscription):
|
||||
def __init__(self):
|
||||
Subscription.__init__(self)
|
||||
0
src/graphql/cpl/graphql/schema/sort/__init__.py
Normal file
0
src/graphql/cpl/graphql/schema/sort/__init__.py
Normal file
19
src/graphql/cpl/graphql/schema/sort/db_model_sort.py
Normal file
19
src/graphql/cpl/graphql/schema/sort/db_model_sort.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class DbModelSort(Sort[T], Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
Sort.__init__(self)
|
||||
self.field("id", SortOrder)
|
||||
self.field("deleted", SortOrder)
|
||||
if Configuration.get("GraphQLAuthModuleEnabled", False):
|
||||
self.field("editor", SortOrder)
|
||||
self.field("created", SortOrder)
|
||||
self.field("updated", SortOrder)
|
||||
9
src/graphql/cpl/graphql/schema/sort/sort.py
Normal file
9
src/graphql/cpl/graphql/schema/sort/sort.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class Sort(Input[T]):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
Input.__init__(self)
|
||||
6
src/graphql/cpl/graphql/schema/sort/sort_order.py
Normal file
6
src/graphql/cpl/graphql/schema/sort/sort_order.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class SortOrder(Enum):
|
||||
ASC = "ASC"
|
||||
DESC = "DESC"
|
||||
88
src/graphql/cpl/graphql/schema/subscription.py
Normal file
88
src/graphql/cpl/graphql/schema/subscription.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import inspect
|
||||
from typing import Any, Type, Optional, Self
|
||||
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency import get_provider, inject
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
from cpl.graphql.abc.query_abc import QueryABC
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.subscription_field import SubscriptionField
|
||||
from cpl.graphql.typing import Selector
|
||||
|
||||
|
||||
class Subscription(QueryABC):
|
||||
|
||||
@inject
|
||||
def __init__(self, bus: EventBusABC):
|
||||
QueryABC.__init__(self)
|
||||
self._bus = bus
|
||||
|
||||
def subscription_field(
|
||||
self,
|
||||
name: str,
|
||||
t: Type,
|
||||
selector: Optional[Selector] = None,
|
||||
channel: Optional[str] = None,
|
||||
) -> SubscriptionField:
|
||||
field = SubscriptionField(name, t, selector, channel)
|
||||
self._fields[name] = field
|
||||
return field
|
||||
|
||||
def with_subscription(self, sub_cls: Type[Self]) -> Self:
|
||||
sub = get_provider().get_service(sub_cls)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider")
|
||||
|
||||
for sub_name, sub_field in sub.get_fields().items():
|
||||
self._fields[sub_name] = sub_field
|
||||
|
||||
return self
|
||||
|
||||
def _field_to_strawberry(self, f: SubscriptionField) -> Any:
|
||||
try:
|
||||
if isinstance(f, SubscriptionField):
|
||||
|
||||
def make_resolver(field: SubscriptionField):
|
||||
async def resolver(root=None, info=None):
|
||||
if not field.public:
|
||||
user = get_user()
|
||||
if not user:
|
||||
raise graphql_error(Unauthorized(f"{field.name}: Authentication required"))
|
||||
|
||||
if field.require_any_permission:
|
||||
ok = any([await user.has_permission(p) for p in field.require_any_permission])
|
||||
if not ok:
|
||||
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
|
||||
|
||||
if field.require_any:
|
||||
perms, resolvers = field.require_any
|
||||
ok = any([await user.has_permission(p) for p in perms])
|
||||
if not ok:
|
||||
ctx = QueryContext([x.name for x in await user.permissions])
|
||||
results = [
|
||||
r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx)
|
||||
for r in resolvers
|
||||
]
|
||||
if not any(results):
|
||||
raise graphql_error(Forbidden(f"{field.name}: Permission denied"))
|
||||
|
||||
async for event in self._bus.subscribe(field.channel):
|
||||
if field.selector is None or field.selector(event, info):
|
||||
yield event
|
||||
|
||||
return resolver
|
||||
|
||||
return strawberry.subscription(resolver=make_resolver(f))
|
||||
|
||||
async def wrapper_resolver(root=None, info=None):
|
||||
yield None
|
||||
|
||||
return strawberry.subscription(resolver=wrapper_resolver)
|
||||
|
||||
except StrawberryException as e:
|
||||
raise Exception(f"Error converting subscription field '{f.name}': {e}") from e
|
||||
25
src/graphql/cpl/graphql/schema/subscription_field.py
Normal file
25
src/graphql/cpl/graphql/schema/subscription_field.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Type, Callable, Optional
|
||||
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import Selector
|
||||
|
||||
|
||||
class SubscriptionField(Field):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
t: Type,
|
||||
selector: Optional[Selector] = None,
|
||||
channel: Optional[str] = None,
|
||||
):
|
||||
super().__init__(name, t)
|
||||
self.selector = selector
|
||||
self.channel = channel or name
|
||||
|
||||
def with_selector(self, selector: Selector) -> "SubscriptionField":
|
||||
self.selector = selector
|
||||
return self
|
||||
|
||||
def with_channel(self, channel: str) -> "SubscriptionField":
|
||||
self.channel = channel
|
||||
return self
|
||||
0
src/graphql/cpl/graphql/service/__init__.py
Normal file
0
src/graphql/cpl/graphql/service/__init__.py
Normal file
52
src/graphql/cpl/graphql/service/graphql.py
Normal file
52
src/graphql/cpl/graphql/service/graphql.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from graphql import GraphQLError
|
||||
|
||||
from cpl.api import APILogger, APIError
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.graphql.service.schema import Schema
|
||||
|
||||
|
||||
class GraphQLService:
|
||||
def __init__(self, logger: APILogger, schema: Schema):
|
||||
self._logger = logger
|
||||
|
||||
if schema.schema is None:
|
||||
raise ValueError("Schema has not been built. Call schema.build() before using the service.")
|
||||
self._schema = schema.schema
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: str,
|
||||
variables: Optional[Dict[str, Any]],
|
||||
request: TRequest,
|
||||
) -> Dict[str, Any]:
|
||||
result = await self._schema.execute(
|
||||
query,
|
||||
variable_values=variables,
|
||||
context_value={"request": request},
|
||||
)
|
||||
|
||||
response_data: Dict[str, Any] = {}
|
||||
if result.errors:
|
||||
errors = []
|
||||
for error in result.errors:
|
||||
if isinstance(error, APIError):
|
||||
self._logger.error(f"GraphQL APIError", error)
|
||||
errors.append({"message": error.error_message, "extensions": {"code": error.status_code}})
|
||||
continue
|
||||
|
||||
if isinstance(error, GraphQLError):
|
||||
|
||||
self._logger.error(f"GraphQLError", error)
|
||||
errors.append({"message": error.message, "extensions": error.extensions})
|
||||
continue
|
||||
|
||||
self._logger.error(f"GraphQL unexpected error", error)
|
||||
errors.append({"message": str(error), "extensions": {"code": 500}})
|
||||
|
||||
response_data["errors"] = errors
|
||||
if result.data:
|
||||
response_data["data"] = result.data
|
||||
|
||||
return response_data
|
||||
76
src/graphql/cpl/graphql/service/schema.py
Normal file
76
src/graphql/cpl/graphql/service/schema.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import logging
|
||||
from typing import Type, Self
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
from cpl.graphql.schema.root_subscription import RootSubscription
|
||||
|
||||
|
||||
class Schema:
|
||||
|
||||
def __init__(self, logger: APILogger, provider: ServiceProvider):
|
||||
self._logger = logger
|
||||
self._provider = provider
|
||||
|
||||
self._types: dict[str, Type[StrawberryProtocol]] = {}
|
||||
|
||||
self._schema = None
|
||||
|
||||
@property
|
||||
def schema(self) -> strawberry.Schema | None:
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def query(self) -> RootQuery:
|
||||
query = self._provider.get_service(RootQuery)
|
||||
if not query:
|
||||
raise ValueError("RootQuery not registered in service provider")
|
||||
return query
|
||||
|
||||
@property
|
||||
def mutation(self) -> RootMutation:
|
||||
mutation = self._provider.get_service(RootMutation)
|
||||
if not mutation:
|
||||
raise ValueError("RootMutation not registered in service provider")
|
||||
return mutation
|
||||
|
||||
@property
|
||||
def subscription(self) -> RootSubscription:
|
||||
subscription = self._provider.get_service(RootSubscription)
|
||||
if not subscription:
|
||||
raise ValueError("RootSubscription not registered in service provider")
|
||||
return subscription
|
||||
|
||||
def with_type(self, t: Type[StrawberryProtocol]) -> Self:
|
||||
self._types[t.__name__] = t
|
||||
return self
|
||||
|
||||
def _get_types(self):
|
||||
types: list[Type] = []
|
||||
for t in self._types.values():
|
||||
t_obj = self._provider.get_service(t)
|
||||
if not t_obj:
|
||||
raise ValueError(f"Type '{t.__name__}' not registered in service provider")
|
||||
types.append(t_obj.to_strawberry())
|
||||
|
||||
return types
|
||||
|
||||
def build(self) -> strawberry.Schema:
|
||||
logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL)
|
||||
|
||||
query = self.query
|
||||
mutation = self.mutation
|
||||
subscription = self.subscription
|
||||
|
||||
self._schema = strawberry.Schema(
|
||||
query=query.to_strawberry() if query.fields_count > 0 else None,
|
||||
mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None,
|
||||
subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None,
|
||||
types=self._get_types(),
|
||||
)
|
||||
return self._schema
|
||||
16
src/graphql/cpl/graphql/typing.py
Normal file
16
src/graphql/cpl/graphql/typing.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import Enum
|
||||
from typing import Type, Callable, List, Tuple, Awaitable, Any
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.auth.permission import Permissions
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
|
||||
TQuery = Type["Query"]
|
||||
Resolver = Callable
|
||||
Selector = Callable[[Any, strawberry.types.Info], bool]
|
||||
ScalarType = str | int | float | bool | object
|
||||
AttributeName = str | property
|
||||
TRequireAnyPermissions = List[Enum | Permissions] | None
|
||||
TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],]
|
||||
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]
|
||||
0
src/graphql/cpl/graphql/utils/__init__.py
Normal file
0
src/graphql/cpl/graphql/utils/__init__.py
Normal file
28
src/graphql/cpl/graphql/utils/name_pipe.py
Normal file
28
src/graphql/cpl/graphql/utils/name_pipe.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from cpl.core.pipes import PipeABC
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.collection import CollectionGraphType
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.object_graph_type import ObjectGraphType
|
||||
|
||||
|
||||
class NamePipe(PipeABC):
|
||||
|
||||
@staticmethod
|
||||
def to_str(value: type, *args) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if not isinstance(value, type):
|
||||
raise ValueError(f"Expected a type, got {type(value)}")
|
||||
|
||||
if issubclass(value, CollectionGraphType):
|
||||
return f"{value.__name__.replace(GraphType.__name__, "")}"
|
||||
|
||||
if issubclass(value, (ObjectGraphType, GraphType)):
|
||||
return value.__name__.replace(GraphType.__name__, "")
|
||||
|
||||
return value.__name__
|
||||
|
||||
@staticmethod
|
||||
def from_str(value: str, *args) -> T:
|
||||
pass
|
||||
17
src/graphql/cpl/graphql/utils/type_collector.py
Normal file
17
src/graphql/cpl/graphql/utils/type_collector.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Type, Any
|
||||
|
||||
|
||||
class TypeCollector:
|
||||
_registry: dict[type | str, Type] = {}
|
||||
|
||||
@classmethod
|
||||
def has(cls, base: type | str) -> bool:
|
||||
return base in cls._registry
|
||||
|
||||
@classmethod
|
||||
def get(cls, base: type | str) -> Type:
|
||||
return cls._registry[base]
|
||||
|
||||
@classmethod
|
||||
def set(cls, base: type | str, gql_type: Type):
|
||||
cls._registry[base] = gql_type
|
||||
30
src/graphql/pyproject.toml
Normal file
30
src/graphql/pyproject.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=70.1.0", "wheel>=0.43.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "cpl-database"
|
||||
version = "2024.7.0"
|
||||
description = "CPL database"
|
||||
readme ="CPL database package"
|
||||
requires-python = ">=3.12"
|
||||
license = { text = "MIT" }
|
||||
authors = [
|
||||
{ name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" }
|
||||
]
|
||||
keywords = ["cpl", "database", "backend", "shared", "library"]
|
||||
|
||||
dynamic = ["dependencies", "optional-dependencies"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://www.sh-edraft.de"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["cpl*"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
dependencies = { file = ["requirements.txt"] }
|
||||
optional-dependencies.dev = { file = ["requirements.dev.txt"] }
|
||||
|
||||
|
||||
1
src/graphql/requirements.dev.txt
Normal file
1
src/graphql/requirements.dev.txt
Normal file
@@ -0,0 +1 @@
|
||||
black==25.1.0
|
||||
2
src/graphql/requirements.txt
Normal file
2
src/graphql/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
cpl-api
|
||||
strawberry-graphql==0.282.0
|
||||
Reference in New Issue
Block a user