Further gql improvements & added test data #181
Some checks failed
Test before pr merge / test-lint (pull_request) Failing after 5s

This commit is contained in:
2025-09-27 21:57:33 +02:00
parent 2e98159d4e
commit af7945fe92
27 changed files with 305 additions and 88 deletions

View File

@@ -1,7 +1,7 @@
from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType, CityFilter, CitySort
from api.src.queries.hello import UserGraphType
from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType
from api.src.queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder
@@ -14,9 +14,12 @@ from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule
from cpl.graphql.application.graphql_app import GraphQLApp
from cpl.graphql.graphql_module import GraphQLModule
from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType
from queries.hello import HelloQuery
from scoped_service import ScopedService
from service import PingService
from test_data_seeder import TestDataSeeder
def main():
@@ -27,29 +30,38 @@ def main():
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging()
(
builder.services.add_structured_logging()
builder.services.add_transient(PingService)
builder.services.add_module(MySQLModule)
builder.services.add_module(ApiModule)
builder.services.add_module(GraphQLModule)
.add_transient(PingService)
.add_module(MySQLModule)
.add_module(ApiModule)
.add_module(GraphQLModule)
.add_scoped(ScopedService)
.add_cache(AuthUser)
.add_cache(Role)
.add_transient(CityGraphType)
.add_transient(CityFilter)
.add_transient(CitySort)
.add_transient(UserGraphType)
.add_transient(UserFilter)
.add_transient(UserSort)
.add_transient(AuthUserGraphType)
.add_transient(AuthUserFilter)
.add_transient(AuthUserSort)
.add_transient(HelloQuery)
# posts
.add_transient(PostDao)
.add_transient(PostGraphType)
.add_transient(PostFilter)
.add_transient(PostSort)
builder.services.add_scoped(ScopedService)
builder.services.add_cache(AuthUser)
builder.services.add_cache(Role)
builder.services.add_transient(CityGraphType)
builder.services.add_transient(CityFilter)
builder.services.add_transient(CitySort)
builder.services.add_transient(UserGraphType)
builder.services.add_transient(UserFilter)
builder.services.add_transient(UserSort)
builder.services.add_transient(HelloQuery)
# test data
.add_singleton(TestDataSeeder)
)
app = builder.build()
app.with_logging()
app.with_migrations("./scripts")
app.with_authentication()
app.with_authorization()
@@ -66,6 +78,7 @@ def main():
schema = app.with_graphql()
schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
app.with_playground()
app.with_graphiql()

View File

View File

@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Self
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
class Post(DbModelABC[Self]):
def __init__(
self,
id: int,
title: str,
content: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._title = title
self._content = content
@property
def title(self) -> str:
return self._title
@property
def content(self) -> str:
return self._content

View File

@@ -0,0 +1,11 @@
from cpl.database.abc import DbModelDaoABC
from model.post import Post
class PostDao(DbModelDaoABC):
def __init__(self):
DbModelDaoABC.__init__(self, Post, "posts")
self.attribute(Post.title, str)
self.attribute(Post.content, str)

View File

@@ -0,0 +1,38 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from model.post import Post
class PostFilter(Filter[Post]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("title", str)
self.field("content", str)
class PostSort(Sort[Post]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("title", SortOrder)
self.field("content", SortOrder)
class PostGraphType(GraphType[Post]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"title",
resolver=lambda root: root.title,
)
self.string_field(
"content",
resolver=lambda root: root.content,
)

View File

@@ -1,11 +1,43 @@
from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City
from api.src.queries.user import User, UserFilter, UserSort, UserGraphType
from cpl.api.middleware.request import get_request
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
users = [User(i, f"User {i}") for i in range(1, 101)]
cities = [City(i, f"City {i}") for i in range(1, 101)]
class AuthUserFilter(Filter[AuthUser]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("username", str)
class AuthUserSort(Sort[AuthUser]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("username", SortOrder)
class AuthUserGraphType(GraphType[AuthUser]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"username",
resolver=lambda root: root.username,
)
class HelloQuery(Query):
def __init__(self):
Query.__init__(self)
@@ -28,3 +60,10 @@ class HelloQuery(Query):
CitySort,
resolver=lambda: cities,
)
self.dao_collection_field(
AuthUserGraphType,
AuthUserDao,
"authUsers",
AuthUserFilter,
AuthUserSort,
)

View File

@@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS `posts` (
`id` INT(30) NOT NULL AUTO_INCREMENT,
`title` VARCHAR(64) NOT NULL,
`content` VARCHAR(512) NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY(`id`)
);

View File

@@ -0,0 +1,31 @@
from faker import Faker
from cpl.database.abc import DataSeederABC
from cpl.query import Enumerable
from model.post import Post
from model.post_dao import PostDao
fake = Faker()
class TestDataSeeder(DataSeederABC):
def __init__(self, posts: PostDao):
DataSeederABC.__init__(self)
self._posts = posts
async def seed(self):
if await self._posts.count() == 0:
await self._seed_posts()
async def _seed_posts(self):
posts = Enumerable.range(0, 100).select(
lambda x: Post(
id=0,
title=fake.sentence(nb_words=6),
content=fake.paragraph(nb_sentences=6),
)
).to_list()
await self._posts.create_many(posts, skip_editor=True)

View File

@@ -5,16 +5,16 @@ from cpl.core.typing import SerialId
from cpl.database.abc.db_model_abc import DbModelABC
class City(DbModelABC):
class City(DbModelABC[Self]):
def __init__(
self,
id: int,
name: str,
zip: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name

View File

@@ -5,7 +5,7 @@ from cpl.core.typing import SerialId
from cpl.database.abc.db_model_abc import DbModelABC
class User(DbModelABC):
class User(DbModelABC[Self]):
def __init__(
self,
@@ -13,9 +13,9 @@ class User(DbModelABC):
name: str,
city_id: int = 0,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name

View File

@@ -1,6 +1,6 @@
import secrets
from datetime import datetime
from typing import Optional, Union
from typing import Optional, Union, Self
from async_property import async_property
@@ -16,7 +16,7 @@ from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__)
class ApiKey(DbModelABC):
class ApiKey(DbModelABC[Self]):
def __init__(
self,
@@ -25,8 +25,8 @@ class ApiKey(DbModelABC):
key: Union[str, bytes],
deleted: bool = False,
editor_id: Optional[Id] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._identifier = identifier

View File

@@ -1,6 +1,6 @@
import uuid
from datetime import datetime
from typing import Optional
from typing import Optional, Self
from async_property import async_property
from keycloak import KeycloakGetError
@@ -13,15 +13,15 @@ from cpl.database.logger import DBLogger
from cpl.dependency import get_provider
class AuthUser(DbModelABC):
class AuthUser(DbModelABC[Self]):
def __init__(
self,
id: SerialId,
keycloak_id: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._keycloak_id = keycloak_id
@@ -87,4 +87,3 @@ class AuthUser(DbModelABC):
self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self)

View File

@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider
from cpl.dependency.context import get_provider
class AuthUserDao(DbModelDaoABC[AuthUser]):

View File

@@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC):
api_key_id: SerialId,
permission_id: SerialId,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
self._api_key_id = api_key_id

View File

@@ -1,20 +1,20 @@
from datetime import datetime
from typing import Optional
from typing import Optional, Self
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
class Permission(DbModelABC):
class Permission(DbModelABC[Self]):
def __init__(
self,
id: SerialId,
name: str,
description: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional
from typing import Optional, Self
from async_property import async_property
@@ -9,16 +9,16 @@ from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider
class Role(DbModelABC):
class Role(DbModelABC[Self]):
def __init__(
self,
id: SerialId,
name: str,
description: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional
from typing import Optional, Self
from async_property import async_property
@@ -8,16 +8,16 @@ from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider
class RolePermission(DbModelABC):
class RolePermission(DbModelABC[Self]):
def __init__(
self,
id: SerialId,
role_id: SerialId,
permission_id: SerialId,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._role_id = role_id

View File

@@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC):
user_id: SerialId,
role_id: SerialId,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
self._user_id = user_id

View File

@@ -2,10 +2,6 @@ import os
from cryptography.fernet import Fernet
from cpl.core.log.logger import Logger
_logger = Logger(__name__)
class CredentialManager:
r"""Handles credential encryption and decryption"""
@@ -14,6 +10,7 @@ class CredentialManager:
@classmethod
def with_secret(cls, file: str = None):
from cpl.core.log import Logger
if file is None:
file = ".secret"
@@ -25,12 +22,12 @@ class CredentialManager:
with open(file, "w") as secret_file:
secret_file.write(Fernet.generate_key().decode())
secret_file.close()
_logger.warning("Secret file not found, regenerating")
Logger(__name__).warning("Secret file not found, regenerating")
with open(file, "r") as secret_file:
secret = secret_file.read().strip()
if secret == "" or secret is None:
_logger.fatal("No secret found in .secret file.")
Logger(__name__).fatal("No secret found in .secret file.")
cls._secret = str(secret)

View File

@@ -46,6 +46,10 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
def table_name(self) -> str:
return self._table_name
@property
def type(self) -> Type[T_DBM]:
return self._model_type
def has_attribute(self, attr_name: Attribute) -> bool:
"""
Check if the attribute exists in the DAO
@@ -490,16 +494,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
table, join_condition = self.__foreign_tables[attr]
builder.with_left_join(table, join_condition)
if filters:
if filters is not None:
await self._build_conditions(builder, filters, external_table_deps)
if sorts:
if sorts is not None:
self._build_sorts(builder, sorts, external_table_deps)
if take:
if take is not None:
builder.with_limit(take)
if skip:
if skip is not None:
builder.with_offset(skip)
for external_table in external_table_deps:

View File

@@ -12,9 +12,9 @@ class DbJoinModelABC[T](DbModelABC[T]):
source_id: Id,
foreign_id: Id,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)

View File

@@ -10,9 +10,9 @@ class DbModelABC(ABC, Generic[T]):
self,
id: Id,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
editor_id: SerialId | None = None,
created: datetime | None= None,
updated: datetime | None= None,
):
self._id = id
self._deleted = deleted

View File

@@ -1,6 +1,6 @@
from typing import Optional
from cpl.core.configuration import Configuration
from cpl.core.configuration.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC

View File

@@ -1,15 +1,15 @@
from datetime import datetime
from typing import Optional
from typing import Optional, Self
from cpl.database.abc import DbModelABC
class ExecutedMigration(DbModelABC):
class ExecutedMigration(DbModelABC[Self]):
def __init__(
self,
migration_id: str,
created: Optional[datetime] = None,
modified: Optional[datetime] = None,
created: datetime | None= None,
modified: datetime | None= None,
):
DbModelABC.__init__(self, migration_id, False, created, modified)

View File

@@ -18,7 +18,7 @@ class CollectionGraphTypeFactory:
gql_type = strawberry.type(
type(
f"{node_type.__name__}Collection",
f"{node_type.__name__.replace("GraphType", "")}Collection",
(),
{
"__annotations__": {

View File

@@ -4,6 +4,7 @@ from typing import Callable, Type, Any, Optional
import strawberry
from strawberry.exceptions import StrawberryException
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
@@ -69,9 +70,6 @@ class Query(StrawberryProtocol):
sort_type: Type[StrawberryProtocol],
resolver: Callable,
) -> Field:
# self._schema.with_type(filter_type)
# self._schema.with_type(sort_type)
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
items = resolver()
if filter:
@@ -103,6 +101,53 @@ class Query(StrawberryProtocol):
f.with_argument(int, "take", default_value=10)
return f
def dao_collection_field(
self,
t: Type[StrawberryProtocol],
dao_type: Type[DataAccessObjectABC],
name: str,
filter_type: Type[StrawberryProtocol],
sort_type: Type[StrawberryProtocol],
) -> Field:
assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC"
dao = self._provider.get_service(dao_type)
if not dao:
raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider")
filter = self._provider.get_service(filter_type)
if not filter:
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
sort = self._provider.get_service(sort_type)
if not sort:
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
async def _resolver(filter=None, sort=None, take=10, skip=0):
sort_dict = None
if sort is not None:
sort_dict = {}
for k, v in sort.__dict__.items():
if v is None:
continue
if isinstance(v, SortOrder):
sort_dict[k] = str(v.value).lower()
continue
sort_dict[k] = str(v).lower()
total_count = await dao.count(filter)
data = await dao.find_by(filter, sort_dict, take, skip)
return Collection(nodes=data, total_count=total_count, count=len(data))
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
f.with_argument(filter.to_strawberry(), "filter")
f.with_argument(sort.to_strawberry(), "sort")
f.with_argument(int, "skip", default_value=0)
f.with_argument(int, "take", default_value=10)
return f
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
@@ -164,4 +209,4 @@ class Query(StrawberryProtocol):
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace))
return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))

View File

@@ -2,5 +2,5 @@ from enum import Enum, auto
class SortOrder(Enum):
ASC = auto()
DESC = auto()
ASC = "ASC"
DESC = "DESC"