Recursive types #181
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
|
||||
from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType
|
||||
from api.src.queries.hello import UserGraphType#, AuthUserFilter, AuthUserSort, AuthUserGraphType
|
||||
from api.src.queries.user import UserFilter, UserSort
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.application.application_builder import ApplicationBuilder
|
||||
@@ -47,9 +47,9 @@ def main():
|
||||
.add_transient(UserGraphType)
|
||||
.add_transient(UserFilter)
|
||||
.add_transient(UserSort)
|
||||
.add_transient(AuthUserGraphType)
|
||||
.add_transient(AuthUserFilter)
|
||||
.add_transient(AuthUserSort)
|
||||
# .add_transient(AuthUserGraphType)
|
||||
# .add_transient(AuthUserFilter)
|
||||
# .add_transient(AuthUserSort)
|
||||
.add_transient(HelloQuery)
|
||||
# test data
|
||||
.add_singleton(TestDataSeeder)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from model.author import Author
|
||||
|
||||
class AuthorFilter(Filter[Author]):
|
||||
class AuthorFilter(DbModelFilter[Author]):
|
||||
def __init__(self):
|
||||
Filter.__init__(self)
|
||||
DbModelFilter.__init__(self, public=True)
|
||||
self.int_field("id")
|
||||
self.string_field("firstName")
|
||||
self.string_field("lastName")
|
||||
@@ -18,10 +18,10 @@ class AuthorSort(Sort[Author]):
|
||||
self.field("firstName", SortOrder)
|
||||
self.field("lastName", SortOrder)
|
||||
|
||||
class AuthorGraphType(GraphType[Author]):
|
||||
class AuthorGraphType(DbModelGraphType[Author]):
|
||||
|
||||
def __init__(self):
|
||||
GraphType.__init__(self)
|
||||
DbModelGraphType.__init__(self, public=True)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.input import Input
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
@@ -11,9 +11,9 @@ from model.post import Post
|
||||
from model.post_dao import PostDao
|
||||
|
||||
|
||||
class PostFilter(Filter[Post]):
|
||||
class PostFilter(DbModelFilter[Post]):
|
||||
def __init__(self):
|
||||
Filter.__init__(self)
|
||||
DbModelFilter.__init__(self, public=True)
|
||||
self.int_field("id")
|
||||
self.filter_field("author", AuthorFilter)
|
||||
self.string_field("title")
|
||||
@@ -26,15 +26,15 @@ class PostSort(Sort[Post]):
|
||||
self.field("title", SortOrder)
|
||||
self.field("content", SortOrder)
|
||||
|
||||
class PostGraphType(GraphType[Post]):
|
||||
class PostGraphType(DbModelGraphType[Post]):
|
||||
|
||||
def __init__(self, authors: AuthorDao):
|
||||
GraphType.__init__(self)
|
||||
DbModelGraphType.__init__(self, public=True)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
).with_public(True)
|
||||
).with_optional().with_public(True)
|
||||
|
||||
async def _a(root: Post):
|
||||
return await authors.get_by_id(root.author_id)
|
||||
|
||||
@@ -11,32 +11,32 @@ from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
users = [User(i, f"User {i}") for i in range(1, 101)]
|
||||
cities = [City(i, f"City {i}") for i in range(1, 101)]
|
||||
|
||||
class AuthUserFilter(Filter[AuthUser]):
|
||||
def __init__(self):
|
||||
Filter.__init__(self)
|
||||
self.field("id", int)
|
||||
self.field("username", str)
|
||||
|
||||
|
||||
class AuthUserSort(Sort[AuthUser]):
|
||||
def __init__(self):
|
||||
Sort.__init__(self)
|
||||
self.field("id", SortOrder)
|
||||
self.field("username", SortOrder)
|
||||
|
||||
class AuthUserGraphType(GraphType[AuthUser]):
|
||||
|
||||
def __init__(self):
|
||||
GraphType.__init__(self)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
)
|
||||
self.string_field(
|
||||
"username",
|
||||
resolver=lambda root: root.username,
|
||||
)
|
||||
# class AuthUserFilter(Filter[AuthUser]):
|
||||
# def __init__(self):
|
||||
# Filter.__init__(self)
|
||||
# self.field("id", int)
|
||||
# self.field("username", str)
|
||||
#
|
||||
#
|
||||
# class AuthUserSort(Sort[AuthUser]):
|
||||
# def __init__(self):
|
||||
# Sort.__init__(self)
|
||||
# self.field("id", SortOrder)
|
||||
# self.field("username", SortOrder)
|
||||
#
|
||||
# class AuthUserGraphType(GraphType[AuthUser]):
|
||||
#
|
||||
# def __init__(self):
|
||||
# GraphType.__init__(self)
|
||||
#
|
||||
# self.int_field(
|
||||
# "id",
|
||||
# resolver=lambda root: root.id,
|
||||
# )
|
||||
# self.string_field(
|
||||
# "username",
|
||||
# resolver=lambda root: root.username,
|
||||
# )
|
||||
|
||||
class HelloQuery(Query):
|
||||
def __init__(self):
|
||||
@@ -60,10 +60,10 @@ class HelloQuery(Query):
|
||||
CitySort,
|
||||
resolver=lambda: cities,
|
||||
)
|
||||
self.dao_collection_field(
|
||||
AuthUserGraphType,
|
||||
AuthUserDao,
|
||||
"authUsers",
|
||||
AuthUserFilter,
|
||||
AuthUserSort,
|
||||
)
|
||||
# self.dao_collection_field(
|
||||
# AuthUserGraphType,
|
||||
# AuthUserDao,
|
||||
# "authUsers",
|
||||
# AuthUserFilter,
|
||||
# AuthUserSort,
|
||||
# )
|
||||
|
||||
@@ -36,7 +36,9 @@ from cpl.dependency.typing import Modules
|
||||
|
||||
class WebApp(WebAppABC):
|
||||
def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
|
||||
WebAppABC.__init__(self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or []))
|
||||
WebAppABC.__init__(
|
||||
self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])
|
||||
)
|
||||
self._app: Starlette | None = None
|
||||
|
||||
self._logger = services.get_service(APILogger)
|
||||
|
||||
@@ -21,7 +21,9 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
def __init__(
|
||||
self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao
|
||||
):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
@@ -92,5 +94,6 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Silent user binding failed: {e}")
|
||||
|
||||
|
||||
def get_request() -> Optional[TRequest]:
|
||||
return _request_context.get()
|
||||
|
||||
@@ -11,6 +11,7 @@ class CredentialManager:
|
||||
@classmethod
|
||||
def with_secret(cls, file: str = None):
|
||||
from cpl.core.log import Logger
|
||||
|
||||
if file is None:
|
||||
file = ".secret"
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@ from abc import ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Generic
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
from cpl.core.typing import Id, SerialId, T
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class DbModelABC(ABC, Generic[T]):
|
||||
@@ -41,14 +44,16 @@ class DbModelABC(ABC, Generic[T]):
|
||||
def editor_id(self, value: SerialId):
|
||||
self._editor_id = value
|
||||
|
||||
# @async_property
|
||||
# async def editor(self):
|
||||
# if self._editor_id is None:
|
||||
# return None
|
||||
#
|
||||
# from data.schemas.administration.user_dao import userDao
|
||||
#
|
||||
# return await userDao.get_by_id(self._editor_id)
|
||||
@async_property
|
||||
async def editor(self):
|
||||
if self._editor_id is None:
|
||||
return None
|
||||
|
||||
from cpl.auth.schema import AuthUserDao
|
||||
|
||||
auth_user_dao = get_provider().get_service(AuthUserDao)
|
||||
|
||||
return await auth_user_dao.get_by_id(self._editor_id)
|
||||
|
||||
@property
|
||||
def created(self) -> datetime:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
|
||||
async def graphiql_endpoint(request):
|
||||
return HTMLResponse("""
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -34,4 +36,5 @@ async def graphiql_endpoint(request):
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -3,7 +3,8 @@ from starlette.responses import Response, HTMLResponse
|
||||
|
||||
|
||||
async def playground_endpoint(request: Request) -> Response:
|
||||
return HTMLResponse("""
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -24,4 +25,5 @@ async def playground_endpoint(request: Request) -> Response:
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
from abc import ABC
|
||||
from asyncio import iscoroutinefunction
|
||||
from asyncio import iscoroutinefunction, iscoroutine
|
||||
from typing import Callable, Type, Any, Optional
|
||||
|
||||
import strawberry
|
||||
@@ -9,11 +10,12 @@ from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.api import Unauthorized, Forbidden
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.error import graphql_error
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import Resolver
|
||||
from cpl.graphql.typing import Resolver, AttributeName
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
|
||||
@@ -36,31 +38,37 @@ class QueryABC(StrawberryProtocol, ABC):
|
||||
|
||||
def field(
|
||||
self,
|
||||
name: str,
|
||||
name: AttributeName,
|
||||
t: type,
|
||||
resolver: Resolver = None,
|
||||
) -> Field:
|
||||
from cpl.graphql.schema.field import Field
|
||||
|
||||
if isinstance(name, property):
|
||||
name = name.fget.__name__
|
||||
|
||||
self._fields[name] = Field(name, t, resolver)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, str, resolver)
|
||||
|
||||
def int_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, int, resolver)
|
||||
|
||||
def float_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, float, resolver)
|
||||
|
||||
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
|
||||
def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, bool, resolver)
|
||||
|
||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||
def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field:
|
||||
return self.field(name, list[t], resolver)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||
if not isinstance(t, type) and callable(t):
|
||||
return self.field(name, t, resolver)
|
||||
|
||||
return self.field(name, t().to_strawberry(), resolver)
|
||||
|
||||
@staticmethod
|
||||
@@ -137,9 +145,10 @@ class QueryABC(StrawberryProtocol, ABC):
|
||||
|
||||
@staticmethod
|
||||
async def _run_resolver(r: Callable, *args, **kwargs):
|
||||
if iscoroutinefunction(r):
|
||||
return await r(*args, **kwargs)
|
||||
return r(*args, **kwargs)
|
||||
result = r(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
def _field_to_strawberry(self, f: Field) -> Any:
|
||||
resolver = None
|
||||
@@ -147,7 +156,7 @@ class QueryABC(StrawberryProtocol, ABC):
|
||||
if f.arguments:
|
||||
resolver = self._build_resolver(f)
|
||||
elif not f.resolver:
|
||||
resolver = lambda *_, **__: None
|
||||
resolver = lambda root: None
|
||||
else:
|
||||
ann = getattr(f.resolver, "__annotations__", {})
|
||||
if "return" not in ann or ann["return"] is None:
|
||||
@@ -165,14 +174,31 @@ class QueryABC(StrawberryProtocol, ABC):
|
||||
if TypeCollector.has(cls):
|
||||
return TypeCollector.get(cls)
|
||||
|
||||
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
|
||||
# register early to handle recursive types
|
||||
TypeCollector.set(cls, gql_cls)
|
||||
|
||||
annotations: dict[str, Any] = {}
|
||||
namespace: dict[str, Any] = {}
|
||||
|
||||
for name, f in self._fields.items():
|
||||
annotations[name] = f.type
|
||||
t = f.type
|
||||
|
||||
if callable(t) and not isinstance(t, type):
|
||||
_t = get_provider().get_service(t())
|
||||
if isinstance(_t, StrawberryProtocol):
|
||||
t = _t.to_strawberry()
|
||||
else:
|
||||
t = _t
|
||||
|
||||
annotations[name] = t if not f.optional else Optional[t]
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
for k, v in namespace.items():
|
||||
setattr(gql_cls, k, v)
|
||||
|
||||
gql_cls.__annotations__ = annotations
|
||||
gql_type = strawberry.type(gql_cls)
|
||||
TypeCollector.set(cls, gql_type)
|
||||
return gql_type
|
||||
|
||||
0
src/cpl-graphql/cpl/graphql/auth/__init__.py
Normal file
0
src/cpl-graphql/cpl/graphql/auth/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from cpl.auth.schema import AuthUser
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
|
||||
|
||||
class AuthUserGraphType(DbModelGraphType):
|
||||
|
||||
def __init__(self):
|
||||
DbModelGraphType.__init__(self)
|
||||
|
||||
self.string_field(AuthUser.keycloak_id, lambda root: root.keycloak_id)
|
||||
self.string_field(AuthUser.username, lambda root: root.username)
|
||||
self.string_field(AuthUser.email, lambda root: root.email)
|
||||
6
src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py
Normal file
6
src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
|
||||
|
||||
|
||||
class GraphQLAuthModule(Module):
|
||||
transient = [AuthUserGraphType]
|
||||
@@ -1,6 +1,8 @@
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.dependency import ServiceCollection
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
@@ -18,6 +20,10 @@ class GraphQLModule(Module):
|
||||
scoped = [GraphQLService]
|
||||
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_module(GraphQLAuthModule)
|
||||
|
||||
@staticmethod
|
||||
def configure(services: ServiceProvider) -> None:
|
||||
schema = services.get_service(Schema)
|
||||
|
||||
@@ -9,13 +9,7 @@ from cpl.core.ctx import get_user
|
||||
|
||||
class QueryContext:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_permissions: Optional[list[Enum | Permission]],
|
||||
is_mutation: bool = False,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs):
|
||||
self._user = get_user()
|
||||
self._user_permissions = user_permissions or []
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ class CollectionGraphTypeFactory:
|
||||
if not node_t:
|
||||
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
|
||||
|
||||
|
||||
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||
|
||||
gql_type = strawberry.type(
|
||||
|
||||
60
src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py
Normal file
60
src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Type, Optional, Generic, Annotated
|
||||
|
||||
import strawberry
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.query import Query
|
||||
|
||||
|
||||
class DbModelGraphType(GraphType[T], Generic[T]):
|
||||
|
||||
def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False):
|
||||
Query.__init__(self)
|
||||
|
||||
self._dao: Optional[DataAccessObjectABC] = None
|
||||
|
||||
if t_dao is not None:
|
||||
dao = self._provider.get_service(t_dao)
|
||||
if dao is not None:
|
||||
self._dao = dao
|
||||
|
||||
self.int_field("id", lambda root: root.id).with_public(public)
|
||||
self.bool_field("deleted", lambda root: root.deleted).with_public(public)
|
||||
|
||||
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
|
||||
|
||||
self.object_field("editor", lambda: AuthUserGraphType, lambda root: root.editor).with_public(public)
|
||||
|
||||
self.string_field("created", lambda root: root.created).with_public(public)
|
||||
self.string_field("updated", lambda root: root.updated).with_public(public)
|
||||
|
||||
# if with_history:
|
||||
# if self._dao is None:
|
||||
# raise ValueError("DAO must be provided to enable history")
|
||||
# self.set_field("history", self._resolve_history).with_public(public)
|
||||
|
||||
self._history_reference_daos: dict[DataAccessObjectABC, str] = {}
|
||||
|
||||
async def _resolve_history(self, root):
|
||||
if self._dao is None:
|
||||
raise Exception("DAO not set for history query")
|
||||
|
||||
history = sorted(
|
||||
[await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)],
|
||||
key=lambda h: h.updated,
|
||||
reverse=True,
|
||||
)
|
||||
return history
|
||||
|
||||
def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None):
|
||||
"""
|
||||
Set the reference DAO for history resolution.
|
||||
:param dao:
|
||||
:param key: The key to use for resolving history.
|
||||
:return:
|
||||
"""
|
||||
if key is None:
|
||||
key = "id"
|
||||
self._history_reference_daos[dao] = key
|
||||
@@ -87,7 +87,7 @@ class Field:
|
||||
self._resolver = resolver
|
||||
return self
|
||||
|
||||
def with_optional(self, optional: bool) -> Self:
|
||||
def with_optional(self, optional: bool = True) -> Self:
|
||||
self._optional = optional
|
||||
return self
|
||||
|
||||
@@ -99,7 +99,9 @@ class Field:
|
||||
self._default = default
|
||||
return self
|
||||
|
||||
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
|
||||
def with_argument(
|
||||
self, name: str, arg_type: type, description: str = None, default_value=None, optional=True
|
||||
) -> Argument:
|
||||
if name in self._args:
|
||||
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
|
||||
self._args[name] = Argument(name, arg_type, description, default_value, optional)
|
||||
|
||||
20
src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py
Normal file
20
src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Type, Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.filter.int_filter import IntFilter
|
||||
from cpl.graphql.schema.filter.string_filter import StringFilter
|
||||
from cpl.graphql.schema.input import Input
|
||||
|
||||
|
||||
class DbModelFilter(Filter[T], Generic[T]):
|
||||
def __init__(self, public: bool = False):
|
||||
Filter.__init__(self)
|
||||
|
||||
self.field("id", IntFilter).with_public(public)
|
||||
self.field("deleted", BoolFilter).with_public(public)
|
||||
# self.field("editor", AuthUserFilter)
|
||||
self.field("created", DateFilter).with_public(public)
|
||||
self.field("updated", DateFilter).with_public(public)
|
||||
@@ -5,10 +5,12 @@ import strawberry
|
||||
from cpl.core.typing import T
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.schema.field import Field
|
||||
from cpl.graphql.typing import AttributeName
|
||||
from cpl.graphql.utils.type_collector import TypeCollector
|
||||
|
||||
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||
|
||||
|
||||
class Input(StrawberryProtocol, Generic[T]):
|
||||
def __init__(self):
|
||||
self._fields: Dict[str, Field] = {}
|
||||
@@ -37,26 +39,29 @@ class Input(StrawberryProtocol, Generic[T]):
|
||||
def get_fields(self) -> dict[str, Field]:
|
||||
return self._fields
|
||||
|
||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
||||
def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field:
|
||||
if isinstance(name, property):
|
||||
name = name.fget.__name__
|
||||
|
||||
self._fields[name] = Field(name, typ, optional=optional)
|
||||
return self._fields[name]
|
||||
|
||||
def string_field(self, name: str, optional: bool = True) -> Field:
|
||||
def string_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, str)
|
||||
|
||||
def int_field(self, name: str, optional: bool = True) -> Field:
|
||||
def int_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, int, optional)
|
||||
|
||||
def float_field(self, name: str, optional: bool = True) -> Field:
|
||||
def float_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, float, optional)
|
||||
|
||||
def bool_field(self, name: str, optional: bool = True) -> Field:
|
||||
def bool_field(self, name: AttributeName, optional: bool = True) -> Field:
|
||||
return self.field(name, bool, optional)
|
||||
|
||||
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
|
||||
def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field:
|
||||
return self.field(name, list[t], optional)
|
||||
|
||||
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
|
||||
return self.field(name, t().to_strawberry(), optional)
|
||||
|
||||
def to_strawberry(self) -> Type:
|
||||
|
||||
@@ -6,6 +6,7 @@ import strawberry
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
|
||||
from cpl.graphql.schema.root_mutation import RootMutation
|
||||
from cpl.graphql.schema.root_query import RootQuery
|
||||
|
||||
@@ -16,7 +17,9 @@ class Schema:
|
||||
self._logger = logger
|
||||
self._provider = provider
|
||||
|
||||
self._types: dict[str, Type[StrawberryProtocol]] = {}
|
||||
self._types: dict[str, Type[StrawberryProtocol]] = {
|
||||
"AuthUserGraphType": AuthUserGraphType,
|
||||
}
|
||||
|
||||
self._schema = None
|
||||
|
||||
|
||||
@@ -7,9 +7,7 @@ from cpl.graphql.query_context import QueryContext
|
||||
TQuery = Type["Query"]
|
||||
Resolver = Callable
|
||||
ScalarType = str | int | float | bool | object
|
||||
|
||||
AttributeName = str | property
|
||||
TRequireAnyPermissions = List[Enum | Permissions] | None
|
||||
TRequireAnyResolvers = List[
|
||||
Callable[[QueryContext], bool | Awaitable[bool]],
|
||||
]
|
||||
TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],]
|
||||
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Type
|
||||
from typing import Type, Any
|
||||
|
||||
|
||||
class TypeCollector:
|
||||
|
||||
Reference in New Issue
Block a user