WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
41 changed files with 281 additions and 134 deletions
Showing only changes of commit 39351a5eb9 - Show all commits

View File

@@ -1,7 +1,7 @@
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType, CityFilter, CitySort from api.src.queries.cities import CityGraphType, CityFilter, CitySort
from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType from api.src.queries.hello import UserGraphType#, AuthUserFilter, AuthUserSort, AuthUserGraphType
from api.src.queries.user import UserFilter, UserSort from api.src.queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder from cpl.application.application_builder import ApplicationBuilder
@@ -47,9 +47,9 @@ def main():
.add_transient(UserGraphType) .add_transient(UserGraphType)
.add_transient(UserFilter) .add_transient(UserFilter)
.add_transient(UserSort) .add_transient(UserSort)
.add_transient(AuthUserGraphType) # .add_transient(AuthUserGraphType)
.add_transient(AuthUserFilter) # .add_transient(AuthUserFilter)
.add_transient(AuthUserSort) # .add_transient(AuthUserSort)
.add_transient(HelloQuery) .add_transient(HelloQuery)
# test data # test data
.add_singleton(TestDataSeeder) .add_singleton(TestDataSeeder)

View File

@@ -1,12 +1,12 @@
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
from model.author import Author from model.author import Author
class AuthorFilter(Filter[Author]): class AuthorFilter(DbModelFilter[Author]):
def __init__(self): def __init__(self):
Filter.__init__(self) DbModelFilter.__init__(self, public=True)
self.int_field("id") self.int_field("id")
self.string_field("firstName") self.string_field("firstName")
self.string_field("lastName") self.string_field("lastName")
@@ -18,10 +18,10 @@ class AuthorSort(Sort[Author]):
self.field("firstName", SortOrder) self.field("firstName", SortOrder)
self.field("lastName", SortOrder) self.field("lastName", SortOrder)
class AuthorGraphType(GraphType[Author]): class AuthorGraphType(DbModelGraphType[Author]):
def __init__(self): def __init__(self):
GraphType.__init__(self) DbModelGraphType.__init__(self, public=True)
self.int_field( self.int_field(
"id", "id",

View File

@@ -1,6 +1,6 @@
from cpl.graphql.query_context import QueryContext from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.input import Input from cpl.graphql.schema.input import Input
from cpl.graphql.schema.mutation import Mutation from cpl.graphql.schema.mutation import Mutation
from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort import Sort
@@ -11,9 +11,9 @@ from model.post import Post
from model.post_dao import PostDao from model.post_dao import PostDao
class PostFilter(Filter[Post]): class PostFilter(DbModelFilter[Post]):
def __init__(self): def __init__(self):
Filter.__init__(self) DbModelFilter.__init__(self, public=True)
self.int_field("id") self.int_field("id")
self.filter_field("author", AuthorFilter) self.filter_field("author", AuthorFilter)
self.string_field("title") self.string_field("title")
@@ -26,15 +26,15 @@ class PostSort(Sort[Post]):
self.field("title", SortOrder) self.field("title", SortOrder)
self.field("content", SortOrder) self.field("content", SortOrder)
class PostGraphType(GraphType[Post]): class PostGraphType(DbModelGraphType[Post]):
def __init__(self, authors: AuthorDao): def __init__(self, authors: AuthorDao):
GraphType.__init__(self) DbModelGraphType.__init__(self, public=True)
self.int_field( self.int_field(
"id", "id",
resolver=lambda root: root.id, resolver=lambda root: root.id,
).with_public(True) ).with_optional().with_public(True)
async def _a(root: Post): async def _a(root: Post):
return await authors.get_by_id(root.author_id) return await authors.get_by_id(root.author_id)

View File

@@ -11,32 +11,32 @@ from cpl.graphql.schema.sort.sort_order import SortOrder
users = [User(i, f"User {i}") for i in range(1, 101)] users = [User(i, f"User {i}") for i in range(1, 101)]
cities = [City(i, f"City {i}") for i in range(1, 101)] cities = [City(i, f"City {i}") for i in range(1, 101)]
class AuthUserFilter(Filter[AuthUser]): # class AuthUserFilter(Filter[AuthUser]):
def __init__(self): # def __init__(self):
Filter.__init__(self) # Filter.__init__(self)
self.field("id", int) # self.field("id", int)
self.field("username", str) # self.field("username", str)
#
#
class AuthUserSort(Sort[AuthUser]): # class AuthUserSort(Sort[AuthUser]):
def __init__(self): # def __init__(self):
Sort.__init__(self) # Sort.__init__(self)
self.field("id", SortOrder) # self.field("id", SortOrder)
self.field("username", SortOrder) # self.field("username", SortOrder)
#
class AuthUserGraphType(GraphType[AuthUser]): # class AuthUserGraphType(GraphType[AuthUser]):
#
def __init__(self): # def __init__(self):
GraphType.__init__(self) # GraphType.__init__(self)
#
self.int_field( # self.int_field(
"id", # "id",
resolver=lambda root: root.id, # resolver=lambda root: root.id,
) # )
self.string_field( # self.string_field(
"username", # "username",
resolver=lambda root: root.username, # resolver=lambda root: root.username,
) # )
class HelloQuery(Query): class HelloQuery(Query):
def __init__(self): def __init__(self):
@@ -60,10 +60,10 @@ class HelloQuery(Query):
CitySort, CitySort,
resolver=lambda: cities, resolver=lambda: cities,
) )
self.dao_collection_field( # self.dao_collection_field(
AuthUserGraphType, # AuthUserGraphType,
AuthUserDao, # AuthUserDao,
"authUsers", # "authUsers",
AuthUserFilter, # AuthUserFilter,
AuthUserSort, # AuthUserSort,
) # )

View File

@@ -36,7 +36,9 @@ from cpl.dependency.typing import Modules
class WebApp(WebAppABC): class WebApp(WebAppABC):
def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None): def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
WebAppABC.__init__(self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])) WebAppABC.__init__(
self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])
)
self._app: Starlette | None = None self._app: Starlette | None = None
self._logger = services.get_service(APILogger) self._logger = services.get_service(APILogger)

View File

@@ -21,7 +21,9 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware): class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): def __init__(
self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao
):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._provider = provider self._provider = provider
@@ -92,5 +94,6 @@ class RequestMiddleware(ASGIMiddleware):
except Exception as e: except Exception as e:
self._logger.debug(f"Silent user binding failed: {e}") self._logger.debug(f"Silent user binding failed: {e}")
def get_request() -> Optional[TRequest]: def get_request() -> Optional[TRequest]:
return _request_context.get() return _request_context.get()

View File

@@ -23,8 +23,8 @@ class RoleSeeder(DataSeederABC):
role_permission_dao: RolePermissionDao, role_permission_dao: RolePermissionDao,
api_key_dao: ApiKeyDao, api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao, api_key_permission_dao: ApiKeyPermissionDao,
user_dao: AuthUserDao, user_dao: AuthUserDao,
role_user_dao: RoleUserDao, role_user_dao: RoleUserDao,
): ):
DataSeederABC.__init__(self) DataSeederABC.__init__(self)
self._logger = logger self._logger = logger

View File

@@ -25,8 +25,8 @@ class ApiKey(DbModelABC[Self]):
key: Union[str, bytes], key: Union[str, bytes],
deleted: bool = False, deleted: bool = False,
editor_id: Optional[Id] = None, editor_id: Optional[Id] = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._identifier = identifier self._identifier = identifier

View File

@@ -20,8 +20,8 @@ class AuthUser(DbModelABC[Self]):
keycloak_id: str, keycloak_id: str,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._keycloak_id = keycloak_id self._keycloak_id = keycloak_id

View File

@@ -16,8 +16,8 @@ class ApiKeyPermission(DbJoinModelABC):
permission_id: SerialId, permission_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
self._api_key_id = api_key_id self._api_key_id = api_key_id

View File

@@ -13,8 +13,8 @@ class Permission(DbModelABC[Self]):
description: str, description: str,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -17,8 +17,8 @@ class Role(DbModelABC[Self]):
description: str, description: str,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -16,8 +16,8 @@ class RolePermission(DbModelABC[Self]):
permission_id: SerialId, permission_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._role_id = role_id self._role_id = role_id

View File

@@ -16,8 +16,8 @@ class RoleUser(DbJoinModelABC):
role_id: SerialId, role_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
self._user_id = user_id self._user_id = user_id

View File

@@ -11,6 +11,7 @@ class CredentialManager:
@classmethod @classmethod
def with_secret(cls, file: str = None): def with_secret(cls, file: str = None):
from cpl.core.log import Logger from cpl.core.log import Logger
if file is None: if file is None:
file = ".secret" file = ".secret"

View File

@@ -13,8 +13,8 @@ class DbJoinModelABC[T](DbModelABC[T]):
foreign_id: Id, foreign_id: Id,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)

View File

@@ -2,7 +2,10 @@ from abc import ABC
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, Generic from typing import Optional, Generic
from async_property import async_property
from cpl.core.typing import Id, SerialId, T from cpl.core.typing import Id, SerialId, T
from cpl.dependency import get_provider
class DbModelABC(ABC, Generic[T]): class DbModelABC(ABC, Generic[T]):
@@ -11,8 +14,8 @@ class DbModelABC(ABC, Generic[T]):
id: Id, id: Id,
deleted: bool = False, deleted: bool = False,
editor_id: SerialId | None = None, editor_id: SerialId | None = None,
created: datetime | None= None, created: datetime | None = None,
updated: datetime | None= None, updated: datetime | None = None,
): ):
self._id = id self._id = id
self._deleted = deleted self._deleted = deleted
@@ -41,14 +44,16 @@ class DbModelABC(ABC, Generic[T]):
def editor_id(self, value: SerialId): def editor_id(self, value: SerialId):
self._editor_id = value self._editor_id = value
# @async_property @async_property
# async def editor(self): async def editor(self):
# if self._editor_id is None: if self._editor_id is None:
# return None return None
#
# from data.schemas.administration.user_dao import userDao from cpl.auth.schema import AuthUserDao
#
# return await userDao.get_by_id(self._editor_id) auth_user_dao = get_provider().get_service(AuthUserDao)
return await auth_user_dao.get_by_id(self._editor_id)
@property @property
def created(self) -> datetime: def created(self) -> datetime:

View File

@@ -8,8 +8,8 @@ class ExecutedMigration(DbModelABC[Self]):
def __init__( def __init__(
self, self,
migration_id: str, migration_id: str,
created: datetime | None= None, created: datetime | None = None,
modified: datetime | None= None, modified: datetime | None = None,
): ):
DbModelABC.__init__(self, migration_id, False, created, modified) DbModelABC.__init__(self, migration_id, False, created, modified)

View File

@@ -1,7 +1,9 @@
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
async def graphiql_endpoint(request): async def graphiql_endpoint(request):
return HTMLResponse(""" return HTMLResponse(
"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -34,4 +36,5 @@ async def graphiql_endpoint(request):
</script> </script>
</body> </body>
</html> </html>
""") """
)

View File

@@ -3,7 +3,8 @@ from starlette.responses import Response, HTMLResponse
async def playground_endpoint(request: Request) -> Response: async def playground_endpoint(request: Request) -> Response:
return HTMLResponse(""" return HTMLResponse(
"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -24,4 +25,5 @@ async def playground_endpoint(request: Request) -> Response:
</script> </script>
</body> </body>
</html> </html>
""") """
)

View File

@@ -1,7 +1,8 @@
import functools import functools
import inspect import inspect
import types
from abc import ABC from abc import ABC
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction, iscoroutine
from typing import Callable, Type, Any, Optional from typing import Callable, Type, Any, Optional
import strawberry import strawberry
@@ -9,11 +10,12 @@ from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Resolver from cpl.graphql.typing import Resolver, AttributeName
from cpl.graphql.utils.type_collector import TypeCollector from cpl.graphql.utils.type_collector import TypeCollector
@@ -36,31 +38,37 @@ class QueryABC(StrawberryProtocol, ABC):
def field( def field(
self, self,
name: str, name: AttributeName,
t: type, t: type,
resolver: Resolver = None, resolver: Resolver = None,
) -> Field: ) -> Field:
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, t, resolver) self._fields[name] = Field(name, t, resolver)
return self._fields[name] return self._fields[name]
def string_field(self, name: str, resolver: Resolver = None) -> Field: def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver) return self.field(name, str, resolver)
def int_field(self, name: str, resolver: Resolver = None) -> Field: def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver) return self.field(name, int, resolver)
def float_field(self, name: str, resolver: Resolver = None) -> Field: def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver) return self.field(name, float, resolver)
def bool_field(self, name: str, resolver: Resolver = None) -> Field: def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver) return self.field(name, bool, resolver)
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver) return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
if not isinstance(t, type) and callable(t):
return self.field(name, t, resolver)
return self.field(name, t().to_strawberry(), resolver) return self.field(name, t().to_strawberry(), resolver)
@staticmethod @staticmethod
@@ -137,9 +145,10 @@ class QueryABC(StrawberryProtocol, ABC):
@staticmethod @staticmethod
async def _run_resolver(r: Callable, *args, **kwargs): async def _run_resolver(r: Callable, *args, **kwargs):
if iscoroutinefunction(r): result = r(*args, **kwargs)
return await r(*args, **kwargs) if inspect.isawaitable(result):
return r(*args, **kwargs) return await result
return result
def _field_to_strawberry(self, f: Field) -> Any: def _field_to_strawberry(self, f: Field) -> Any:
resolver = None resolver = None
@@ -147,7 +156,7 @@ class QueryABC(StrawberryProtocol, ABC):
if f.arguments: if f.arguments:
resolver = self._build_resolver(f) resolver = self._build_resolver(f)
elif not f.resolver: elif not f.resolver:
resolver = lambda *_, **__: None resolver = lambda root: None
else: else:
ann = getattr(f.resolver, "__annotations__", {}) ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None: if "return" not in ann or ann["return"] is None:
@@ -165,14 +174,31 @@ class QueryABC(StrawberryProtocol, ABC):
if TypeCollector.has(cls): if TypeCollector.has(cls):
return TypeCollector.get(cls) return TypeCollector.get(cls)
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
# register early to handle recursive types
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {} annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {} namespace: dict[str, Any] = {}
for name, f in self._fields.items(): for name, f in self._fields.items():
annotations[name] = f.type t = f.type
if callable(t) and not isinstance(t, type):
_t = get_provider().get_service(t())
if isinstance(_t, StrawberryProtocol):
t = _t.to_strawberry()
else:
t = _t
annotations[name] = t if not f.optional else Optional[t]
namespace[name] = self._field_to_strawberry(f) namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations namespace["__annotations__"] = annotations
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) for k, v in namespace.items():
setattr(gql_cls, k, v)
gql_cls.__annotations__ = annotations
gql_type = strawberry.type(gql_cls)
TypeCollector.set(cls, gql_type) TypeCollector.set(cls, gql_type)
return gql_type return gql_type

View File

@@ -0,0 +1,12 @@
from cpl.auth.schema import AuthUser
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class AuthUserGraphType(DbModelGraphType):
def __init__(self):
DbModelGraphType.__init__(self)
self.string_field(AuthUser.keycloak_id, lambda root: root.keycloak_id)
self.string_field(AuthUser.username, lambda root: root.username)
self.string_field(AuthUser.email, lambda root: root.email)

View File

@@ -0,0 +1,6 @@
from cpl.dependency.module.module import Module
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
class GraphQLAuthModule(Module):
transient = [AuthUserGraphType]

View File

@@ -11,4 +11,4 @@ def graphql_error(api_error: APIError) -> GraphQLError:
"code": api_error.status_code, "code": api_error.status_code,
}, },
original_error=api_error, original_error=api_error,
) )

View File

@@ -1,6 +1,8 @@
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.dependency import ServiceCollection
from cpl.dependency.module.module import Module from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.filter.filter import Filter
@@ -18,6 +20,10 @@ class GraphQLModule(Module):
scoped = [GraphQLService] scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
@staticmethod
def register(collection: ServiceCollection):
collection.add_module(GraphQLAuthModule)
@staticmethod @staticmethod
def configure(services: ServiceProvider) -> None: def configure(services: ServiceProvider) -> None:
schema = services.get_service(Schema) schema = services.get_service(Schema)

View File

@@ -9,13 +9,7 @@ from cpl.core.ctx import get_user
class QueryContext: class QueryContext:
def __init__( def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs):
self,
user_permissions: Optional[list[Enum | Permission]],
is_mutation: bool = False,
*args,
**kwargs
):
self._user = get_user() self._user = get_user()
self._user_permissions = user_permissions or [] self._user_permissions = user_permissions or []

View File

@@ -19,7 +19,6 @@ class CollectionGraphTypeFactory:
if not node_t: if not node_t:
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider") raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_type = strawberry.type( gql_type = strawberry.type(

View File

@@ -0,0 +1,60 @@
from typing import Type, Optional, Generic, Annotated
import strawberry
from cpl.core.typing import T
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
class DbModelGraphType(GraphType[T], Generic[T]):
def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False):
Query.__init__(self)
self._dao: Optional[DataAccessObjectABC] = None
if t_dao is not None:
dao = self._provider.get_service(t_dao)
if dao is not None:
self._dao = dao
self.int_field("id", lambda root: root.id).with_public(public)
self.bool_field("deleted", lambda root: root.deleted).with_public(public)
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
self.object_field("editor", lambda: AuthUserGraphType, lambda root: root.editor).with_public(public)
self.string_field("created", lambda root: root.created).with_public(public)
self.string_field("updated", lambda root: root.updated).with_public(public)
# if with_history:
# if self._dao is None:
# raise ValueError("DAO must be provided to enable history")
# self.set_field("history", self._resolve_history).with_public(public)
self._history_reference_daos: dict[DataAccessObjectABC, str] = {}
async def _resolve_history(self, root):
if self._dao is None:
raise Exception("DAO not set for history query")
history = sorted(
[await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)],
key=lambda h: h.updated,
reverse=True,
)
return history
def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None):
"""
Set the reference DAO for history resolution.
:param dao:
:param key: The key to use for resolving history.
:return:
"""
if key is None:
key = "id"
self._history_reference_daos[dao] = key

View File

@@ -5,7 +5,7 @@ from cpl.graphql.schema.argument import Argument
from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers
class Field: class Field:
def __init__( def __init__(
self, self,
@@ -87,7 +87,7 @@ class Field:
self._resolver = resolver self._resolver = resolver
return self return self
def with_optional(self, optional: bool) -> Self: def with_optional(self, optional: bool = True) -> Self:
self._optional = optional self._optional = optional
return self return self
@@ -99,7 +99,9 @@ class Field:
self._default = default self._default = default
return self return self
def with_argument(self, name: str, arg_type: type, description: str = None, default_value=None, optional=True) -> Argument: def with_argument(
self, name: str, arg_type: type, description: str = None, default_value=None, optional=True
) -> Argument:
if name in self._args: if name in self._args:
raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'")
self._args[name] = Argument(name, arg_type, description, default_value, optional) self._args[name] = Argument(name, arg_type, description, default_value, optional)

View File

@@ -15,4 +15,4 @@ class DateFilter(Input[datetime]):
self.field("isNull", datetime, optional=True) self.field("isNull", datetime, optional=True)
self.field("isNotNull", datetime, optional=True) self.field("isNotNull", datetime, optional=True)
self.field("in", list[datetime], optional=True) self.field("in", list[datetime], optional=True)
self.field("notIn", list[datetime], optional=True) self.field("notIn", list[datetime], optional=True)

View File

@@ -0,0 +1,20 @@
from typing import Type, Generic
from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.input import Input
class DbModelFilter(Filter[T], Generic[T]):
def __init__(self, public: bool = False):
Filter.__init__(self)
self.field("id", IntFilter).with_public(public)
self.field("deleted", BoolFilter).with_public(public)
# self.field("editor", AuthUserFilter)
self.field("created", DateFilter).with_public(public)
self.field("updated", DateFilter).with_public(public)

View File

@@ -13,4 +13,4 @@ class IntFilter(Input[int]):
self.field("isNull", int, optional=True) self.field("isNull", int, optional=True)
self.field("isNotNull", int, optional=True) self.field("isNotNull", int, optional=True)
self.field("in", list[int], optional=True) self.field("in", list[int], optional=True)
self.field("notIn", list[int], optional=True) self.field("notIn", list[int], optional=True)

View File

@@ -7,4 +7,4 @@ from cpl.graphql.schema.query import Query
class GraphType(Query, Generic[T]): class GraphType(Query, Generic[T]):
def __init__(self): def __init__(self):
Query.__init__(self) Query.__init__(self)

View File

@@ -5,10 +5,12 @@ import strawberry
from cpl.core.typing import T from cpl.core.typing import T
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.typing import AttributeName
from cpl.graphql.utils.type_collector import TypeCollector from cpl.graphql.utils.type_collector import TypeCollector
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
class Input(StrawberryProtocol, Generic[T]): class Input(StrawberryProtocol, Generic[T]):
def __init__(self): def __init__(self):
self._fields: Dict[str, Field] = {} self._fields: Dict[str, Field] = {}
@@ -37,26 +39,29 @@ class Input(StrawberryProtocol, Generic[T]):
def get_fields(self) -> dict[str, Field]: def get_fields(self) -> dict[str, Field]:
return self._fields return self._fields
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True) -> Field: def field(self, name: AttributeName, typ: Union[type, "Input"], optional: bool = True) -> Field:
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, typ, optional=optional) self._fields[name] = Field(name, typ, optional=optional)
return self._fields[name] return self._fields[name]
def string_field(self, name: str, optional: bool = True) -> Field: def string_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, str) return self.field(name, str)
def int_field(self, name: str, optional: bool = True) -> Field: def int_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, int, optional) return self.field(name, int, optional)
def float_field(self, name: str, optional: bool = True) -> Field: def float_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, float, optional) return self.field(name, float, optional)
def bool_field(self, name: str, optional: bool = True) -> Field: def bool_field(self, name: AttributeName, optional: bool = True) -> Field:
return self.field(name, bool, optional) return self.field(name, bool, optional)
def list_field(self, name: str, t: type, optional: bool = True) -> Field: def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field:
return self.field(name, list[t], optional) return self.field(name, list[t], optional)
def object_field(self, name: str, t: Type[StrawberryProtocol], optional: bool = True) -> Field: def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field:
return self.field(name, t().to_strawberry(), optional) return self.field(name, t().to_strawberry(), optional)
def to_strawberry(self) -> Type: def to_strawberry(self) -> Type:

View File

@@ -3,4 +3,4 @@ from enum import Enum, auto
class SortOrder(Enum): class SortOrder(Enum):
ASC = "ASC" ASC = "ASC"
DESC = "DESC" DESC = "DESC"

View File

@@ -16,10 +16,10 @@ class GraphQLService:
self._schema = schema.schema self._schema = schema.schema
async def execute( async def execute(
self, self,
query: str, query: str,
variables: Optional[Dict[str, Any]], variables: Optional[Dict[str, Any]],
request: TRequest, request: TRequest,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
result = await self._schema.execute( result = await self._schema.execute(
query, query,

View File

@@ -6,6 +6,7 @@ import strawberry
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.auth.administration.auth_user_graph_type import AuthUserGraphType
from cpl.graphql.schema.root_mutation import RootMutation from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery from cpl.graphql.schema.root_query import RootQuery
@@ -16,7 +17,9 @@ class Schema:
self._logger = logger self._logger = logger
self._provider = provider self._provider = provider
self._types: dict[str, Type[StrawberryProtocol]] = {} self._types: dict[str, Type[StrawberryProtocol]] = {
"AuthUserGraphType": AuthUserGraphType,
}
self._schema = None self._schema = None

View File

@@ -7,9 +7,7 @@ from cpl.graphql.query_context import QueryContext
TQuery = Type["Query"] TQuery = Type["Query"]
Resolver = Callable Resolver = Callable
ScalarType = str | int | float | bool | object ScalarType = str | int | float | bool | object
AttributeName = str | property
TRequireAnyPermissions = List[Enum | Permissions] | None TRequireAnyPermissions = List[Enum | Permissions] | None
TRequireAnyResolvers = List[ TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],]
Callable[[QueryContext], bool | Awaitable[bool]],
]
TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers] TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers]

View File

@@ -1,4 +1,4 @@
from typing import Type from typing import Type, Any
class TypeCollector: class TypeCollector:
@@ -14,4 +14,4 @@ class TypeCollector:
@classmethod @classmethod
def set(cls, base: type, gql_type: Type): def set(cls, base: type, gql_type: Type):
cls._registry[base] = gql_type cls._registry[base] = gql_type