WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
41 changed files with 281 additions and 134 deletions
Showing only changes of commit 39351a5eb9 - Show all commits

View File

@@ -1,7 +1,7 @@
from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType
from api.src.queries.hello import UserGraphType#, AuthUserFilter, AuthUserSort, AuthUserGraphType
from api.src.queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder
@@ -47,9 +47,9 @@ def main():
.add_transient(UserGraphType)
.add_transient(UserFilter)
.add_transient(UserSort)
.add_transient(AuthUserGraphType)
.add_transient(AuthUserFilter)
.add_transient(AuthUserSort)
# .add_transient(AuthUserGraphType)
# .add_transient(AuthUserFilter)
# .add_transient(AuthUserSort)
.add_transient(HelloQuery)
# test data
.add_singleton(TestDataSeeder)

View File

@@ -1,12 +1,12 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from model.author import Author
class AuthorFilter(Filter[Author]):
class AuthorFilter(DbModelFilter[Author]):
def __init__(self):
Filter.__init__(self)
DbModelFilter.__init__(self, public=True)
self.int_field("id")
self.string_field("firstName")
self.string_field("lastName")
@@ -18,10 +18,10 @@ class AuthorSort(Sort[Author]):
self.field("firstName", SortOrder)
self.field("lastName", SortOrder)
class AuthorGraphType(GraphType[Author]):
class AuthorGraphType(DbModelGraphType[Author]):
def __init__(self):
GraphType.__init__(self)
DbModelGraphType.__init__(self, public=True)
self.int_field(
"id",

View File

@@ -1,6 +1,6 @@
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.input import Input
from cpl.graphql.schema.mutation import Mutation
from cpl.graphql.schema.sort.sort import Sort
@@ -11,9 +11,9 @@ from model.post import Post
from model.post_dao import PostDao
class PostFilter(Filter[Post]):
class PostFilter(DbModelFilter[Post]):
def __init__(self):
Filter.__init__(self)
DbModelFilter.__init__(self, public=True)
self.int_field("id")
self.filter_field("author", AuthorFilter)
self.string_field("title")
@@ -26,15 +26,15 @@ class PostSort(Sort[Post]):
self.field("title", SortOrder)
self.field("content", SortOrder)
class PostGraphType(GraphType[Post]):
class PostGraphType(DbModelGraphType[Post]):
def __init__(self, authors: AuthorDao):
GraphType.__init__(self)
DbModelGraphType.__init__(self, public=True)
self.int_field(
"id",
resolver=lambda root: root.id,
).with_public(True)
).with_optional().with_public(True)
async def _a(root: Post):
return await authors.get_by_id(root.author_id)

View File

@@ -11,32 +11,32 @@ from cpl.graphql.schema.sort.sort_order import SortOrder
users = [User(i, f"User {i}") for i in range(1, 101)]
cities = [City(i, f"City {i}") for i in range(1, 101)]
class AuthUserFilter(Filter[AuthUser]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("username", str)
class AuthUserSort(Sort[AuthUser]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("username", SortOrder)
class AuthUserGraphType(GraphType[AuthUser]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"username",
resolver=lambda root: root.username,
)
# class AuthUserFilter(Filter[AuthUser]):
# def __init__(self):
# Filter.__init__(self)
# self.field("id", int)
# self.field("username", str)
#
#
# class AuthUserSort(Sort[AuthUser]):
# def __init__(self):
# Sort.__init__(self)
# self.field("id", SortOrder)
# self.field("username", SortOrder)
#
# class AuthUserGraphType(GraphType[AuthUser]):
#
# def __init__(self):
# GraphType.__init__(self)
#
# self.int_field(
# "id",
# resolver=lambda root: root.id,
# )
# self.string_field(
# "username",
# resolver=lambda root: root.username,
# )
class HelloQuery(Query):
def __init__(self):
@@ -60,10 +60,10 @@ class HelloQuery(Query):
CitySort,
resolver=lambda: cities,
)
self.dao_collection_field(
AuthUserGraphType,
AuthUserDao,
"authUsers",
AuthUserFilter,
AuthUserSort,
)
# self.dao_collection_field(
# AuthUserGraphType,
# AuthUserDao,
# "authUsers",
# AuthUserFilter,
# AuthUserSort,
# )

View File

@@ -36,7 +36,9 @@ from cpl.dependency.typing import Modules
class WebApp(WebAppABC):
def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
WebAppABC.__init__(self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or []))
WebAppABC.__init__(
self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])
)
self._app: Starlette | None = None
self._logger = services.get_service(APILogger)

View File

@@ -21,7 +21,9 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao):
def __init__(
self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao
):
ASGIMiddleware.__init__(self, app)
self._provider = provider
@@ -92,5 +94,6 @@ class RequestMiddleware(ASGIMiddleware):
except Exception as e:
self._logger.debug(f"Silent user binding failed: {e}")
def get_request() -> Optional[TRequest]:
return _request_context.get()

View File

@@ -23,8 +23,8 @@ class RoleSeeder(DataSeederABC):
role_permission_dao: RolePermissionDao,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
user_dao: AuthUserDao,
role_user_dao: RoleUserDao,
user_dao: AuthUserDao,
role_user_dao: RoleUserDao,
):
DataSeederABC.__init__(self)
self._logger = logger

View File

@@ -25,8 +25,8 @@ class ApiKey(DbModelABC[Self]):
key: Union[str, bytes],
deleted: bool = False,
editor_id: Optional[Id] = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._identifier = identifier

View File

@@ -20,8 +20,8 @@ class AuthUser(DbModelABC[Self]):
keycloak_id: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._keycloak_id = keycloak_id

View File

@@ -16,8 +16,8 @@ class ApiKeyPermission(DbJoinModelABC):
permission_id: SerialId,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
self._api_key_id = api_key_id

View File

@@ -13,8 +13,8 @@ class Permission(DbModelABC[Self]):
description: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name

View File

@@ -17,8 +17,8 @@ class Role(DbModelABC[Self]):
description: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name

View File

@@ -16,8 +16,8 @@ class RolePermission(DbModelABC[Self]):
permission_id: SerialId,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._role_id = role_id

View File

@@ -16,8 +16,8 @@ class RoleUser(DbJoinModelABC):
role_id: SerialId,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
self._user_id = user_id

View File

@@ -11,6 +11,7 @@ class CredentialManager:
@classmethod
def with_secret(cls, file: str = None):
from cpl.core.log import Logger
if file is None:
file = ".secret"

View File

@@ -13,8 +13,8 @@ class DbJoinModelABC[T](DbModelABC[T]):
foreign_id: Id,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)

View File

@@ -2,7 +2,10 @@ from abc import ABC
from datetime import datetime, timezone
from typing import Optional, Generic
from async_property import async_property
from cpl.core.typing import Id, SerialId, T
from cpl.dependency import get_provider
class DbModelABC(ABC, Generic[T]):
@@ -11,8 +14,8 @@ class DbModelABC(ABC, Generic[T]):
id: Id,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
created: datetime | None = None,
updated: datetime | None = None,
):
self._id = id
self._deleted = deleted
@@ -41,14 +44,16 @@ class DbModelABC(ABC, Generic[T]):
def editor_id(self, value: SerialId):
self._editor_id = value
# @async_property
# async def editor(self):
# if self._editor_id is None:
# return None
#
# from data.schemas.administration.user_dao import userDao
#
# return await userDao.get_by_id(self._editor_id)
@async_property
async def editor(self):
if self._editor_id is None:
return None
from cpl.auth.schema import AuthUserDao
auth_user_dao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._editor_id)
@property
def created(self) -> datetime:

View File

@@ -8,8 +8,8 @@ class ExecutedMigration(DbModelABC[Self]):
def __init__(
self,
migration_id: str,
created: datetime | None= None,
modified: datetime | None= None,
created: datetime | None = None,
modified: datetime | None = None,
):
DbModelABC.__init__(self, migration_id, False, created, modified)

View File

@@ -1,7 +1,9 @@
from starlette.responses import HTMLResponse
async def graphiql_endpoint(request):
return HTMLResponse("""
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
@@ -34,4 +36,5 @@ async def graphiql_endpoint(request):
</script>
</body>
</html>
""")
"""
)

View File

@@ -3,7 +3,8 @@ from starlette.responses import Response, HTMLResponse
async def playground_endpoint(request: Request) -> Response:
return HTMLResponse("""
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
@@ -24,4 +25,5 @@ async def playground_endpoint(request: Request) -> Response:
</script>
</body>
</html>
""")
"""
)

View File

@@ -1,7 +1,8 @@
import functools
import inspect
import types
from abc import ABC
from asyncio import iscoroutinefunction
from asyncio import iscoroutinefunction, iscoroutine
from typing import Callable, Type, Any, Optional
import strawberry
@@ -9,11 +10,12 @@ from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Resolver
from cpl.graphql.typing import Resolver, AttributeName
from cpl.graphql.utils.type_collector import TypeCollector
@@ -36,31 +38,37 @@ class QueryABC(StrawberryProtocol, ABC):
def field(
self,
name: str,
name: AttributeName,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> Field:
def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> Field:
def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> Field:
def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> Field:
def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
if not isinstance(t, type) and callable(t):
return self.field(name, t, resolver)
return self.field(name, t().to_strawberry(), resolver)
@staticmethod
@@ -137,9 +145,10 @@ class QueryABC(StrawberryProtocol, ABC):
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
if iscoroutinefunction(r):
return await r(*args, **kwargs)
return r(*args, **kwargs)
result = r(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
@@ -147,7 +156,7 @@ class QueryABC(StrawberryProtocol, ABC):
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda *_, **__: None
resolver = lambda root: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
@@ -165,14 +174,31 @@ class QueryABC(StrawberryProtocol, ABC):
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
# register early to handle recursive types
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
annotations[name] = f.type
t = f.type
if callable(t) and not isinstance(t, type):
_t = get_provider().get_service(t())
if isinstance(_t, StrawberryProtocol):
t = _t.to_strawberry()
else:
t = _t
annotations[name] = t if not f.optional else Optional[t]
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
for k, v in namespace.items():
setattr(gql_cls, k, v)
gql_cls.__annotations__ = annotations
gql_type = strawberry.type(gql_cls)
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,12 @@
from cpl.auth.schema import AuthUser
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class AuthUserGraphType(DbModelGraphType):
def __init__(self):
DbModelGraphType.__init__(self)
self.string_field(AuthUser.keycloak_id, lambda root: root.keycloak_id)
self.string_field(AuthUser.username, lambda root: root.username)
self.string_field(AuthUser.email, lambda root: root.email)

View File

@@ -0,0 +1,6 @@
from cpl.dependency.module.module import Module
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
class GraphQLAuthModule(Module):
transient = [AuthUserGraphType]

View File

@@ -11,4 +11,4 @@ def graphql_error(api_error: APIError) -> GraphQLError:
"code": api_error.status_code,
},
original_error=api_error,
)
)

View File

@@ -1,6 +1,8 @@
from cpl.api.api_module import ApiModule
from cpl.dependency import ServiceCollection
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
@@ -18,6 +20,10 @@ class GraphQLModule(Module):
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
@staticmethod
def register(collection: ServiceCollection):
collection.add_module(GraphQLAuthModule)
@staticmethod
def configure(services: ServiceProvider) -> None:
schema = services.get_service(Schema)

View File

@@ -9,13 +9,7 @@ from cpl.core.ctx import get_user
class QueryContext:
def __init__(
self,
user_permissions: Optional[list[Enum | Permission]],
is_mutation: bool = False,
*args,
**kwargs
):
def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs):
self._user = get_user()
self._user_permissions = user_permissions or []

View File

@@ -19,7 +19,6 @@ class CollectionGraphTypeFactory:
if not node_t:
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_type = strawberry.type(

View File

@@ -0,0 +1,60 @@
from typing import Type, Optional, Generic, Annotated
import strawberry
from cpl.core.typing import T
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
class DbModelGraphType(GraphType[T], Generic[T]):
def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False):
Query.__init__(self)
self._dao: Optional[DataAccessObjectABC] = None
if t_dao is not None:
dao = self._provider.get_service(t_dao)
if dao is not None:
self._dao = dao
self.int_field("id", lambda root: root.id).with_public(public)
self.bool_field("deleted", lambda root: root.deleted).with_public(public)
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
self.object_field("editor", lambda: AuthUserGraphType, lambda root: root.editor).with_public(public)
self.string_field("created", lambda root: root.created).with_public(public)
self.string_field("updated", lambda root: root.updated).with_public(public)
# if with_history:
# if self._dao is None:
# raise ValueError("DAO must be provided to enable history")
# self.set_field("history", self._resolve_history).with_public(public)
self._history_reference_daos: dict[DataAccessObjectABC, str] = {}
async def _resolve_history(self, root):
if self._dao is None:
raise Exception("DAO not set for history query")
history = sorted(
[await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)],
key=lambda h: h.updated,
reverse=True,
)
return history
def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None):
"""
Set the reference DAO for history resolution.
:param dao:
:param key: The key to use for resolving history.
:return:
"""
if key is None:
key = "id"
self._history_reference_daos[dao] = key

View File

@@ -5,7 +5,7 @@ from cpl.graphql.schema.argument import Argument
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
class Field:
class Field:
def __init__(
self,
@@ -87,7 +87,7 @@ class Field:
self._resolver = resolver
return self
def with_optional(self, optional: bool) -> Self:
def with_optional(self, optional: bool = True) -> Self:
self._optional = optional
return self
@@ -99,7 +99,9 @@ class Field:
self._default = default
return self
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument:
def with_argument(
self, name: str, arg_type: type, description: str = None, default_value=None, optional=True
) -> Argument:
if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(name, arg_type, description, default_value, optional)

View File

@@ -15,4 +15,4 @@ class DateFilter(Input[datetime]):
self.field("isNull", datetime, optional=True)
self.field("isNotNull", datetime, optional=True)
self.field("in", list[datetime], optional=True)
self.field("notIn", list[datetime], optional=True)
self.field("notIn", list[datetime], optional=True)

View File

@@ -0,0 +1,20 @@
from typing import Type, Generic
from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.input import Input
class DbModelFilter(Filter[T], Generic[T]):
def __init__(self, public: bool = False):
Filter.__init__(self)
self.field("id", IntFilter).with_public(public)
self.field("deleted", BoolFilter).with_public(public)
# self.field("editor", AuthUserFilter)
self.field("created", DateFilter).with_public(public)
self.field("updated", DateFilter).with_public(public)

View File

@@ -13,4 +13,4 @@ class IntFilter(Input[int]):
self.field("isNull", int, optional=True)
self.field("isNotNull", int, optional=True)
self.field("in", list[int], optional=True)
self.field("notIn", list[int], optional=True)
self.field("notIn", list[int], optional=True)

View File

@@ -7,4 +7,4 @@ from cpl.graphql.schema.query import Query
class GraphType(Query, Generic[T]):
def __init__(self):
Query.__init__(self)
Query.__init__(self)

View File

@@ -5,10 +5,12 @@ import strawberry
from cpl.core.typing import T
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import AttributeName
from cpl.graphql.utils.type_collector import TypeCollector
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
class Input(StrawberryProtocol, Generic[T]):
def __init__(self):
self._fields: Dict[str, Field] = {}
@@ -37,26 +39,29 @@ class Input(StrawberryProtocol, Generic[T]):
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field:
def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field:
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, typ, optional=optional)
return self._fields[name]
def string_field(self, name: str, optional: bool = True) -> Field:
def string_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, str)
def int_field(self, name: str, optional: bool = True) -> Field:
def int_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, int, optional)
def float_field(self, name: str, optional: bool = True) -> Field:
def float_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, float, optional)
def bool_field(self, name: str, optional: bool = True) -> Field:
def bool_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, bool, optional)
def list_field(self, name: str, t: type, optional: bool = True) -> Field:
def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field:
return self.field(name, list[t], optional)
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
return self.field(name, t().to_strawberry(), optional)
def to_strawberry(self) -> Type:

View File

@@ -3,4 +3,4 @@ from enum import Enum, auto
class SortOrder(Enum):
ASC = "ASC"
DESC = "DESC"
DESC = "DESC"

View File

@@ -16,10 +16,10 @@ class GraphQLService:
self._schema = schema.schema
async def execute(
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
self,
query: str,
variables: Optional[Dict[str, Any]],
request: TRequest,
) -> Dict[str, Any]:
result = await self._schema.execute(
query,

View File

@@ -6,6 +6,7 @@ import strawberry
from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
@@ -16,7 +17,9 @@ class Schema:
self._logger = logger
self._provider = provider
self._types: dict[str, Type[StrawberryProtocol]] = {}
self._types: dict[str, Type[StrawberryProtocol]] = {
"AuthUserGraphType": AuthUserGraphType,
}
self._schema = None

View File

@@ -7,9 +7,7 @@ from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"]
Resolver = Callable
ScalarType = str | int | float | bool | object
AttributeName = str | property
TRequireAnyPermissions = List[Enum | Permissions] | None
TRequireAnyResolvers = List[
Callable[[QueryContext], bool | Awaitable[bool]],
]
TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],]
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]

View File

@@ -1,4 +1,4 @@
from typing import Type
from typing import Type, Any
class TypeCollector:
@@ -14,4 +14,4 @@ class TypeCollector:
@classmethod
def set(cls, base: type, gql_type: Type):
cls._registry[base] = gql_type
cls._registry[base] = gql_type