WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
27 changed files with 305 additions and 88 deletions
Showing only changes of commit d8c60defba - Show all commits

View File

@@ -1,7 +1,7 @@
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from api.src.queries.cities import CityGraphType, CityFilter, CitySort from api.src.queries.cities import CityGraphType, CityFilter, CitySort
from api.src.queries.hello import UserGraphType from api.src.queries.hello import UserGraphType, AuthUserFilter, AuthUserSort, AuthUserGraphType
from api.src.queries.user import UserFilter, UserSort from api.src.queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.application.application_builder import ApplicationBuilder from cpl.application.application_builder import ApplicationBuilder
@@ -14,9 +14,12 @@ from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule from cpl.database.mysql.mysql_module import MySQLModule
from cpl.graphql.application.graphql_app import GraphQLApp from cpl.graphql.application.graphql_app import GraphQLApp
from cpl.graphql.graphql_module import GraphQLModule from cpl.graphql.graphql_module import GraphQLModule
from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType
from queries.hello import HelloQuery from queries.hello import HelloQuery
from scoped_service import ScopedService from scoped_service import ScopedService
from service import PingService from service import PingService
from test_data_seeder import TestDataSeeder
def main(): def main():
@@ -27,29 +30,38 @@ def main():
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True) Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging() # builder.services.add_logging()
(
builder.services.add_structured_logging() builder.services.add_structured_logging()
builder.services.add_transient(PingService) .add_transient(PingService)
builder.services.add_module(MySQLModule) .add_module(MySQLModule)
builder.services.add_module(ApiModule) .add_module(ApiModule)
builder.services.add_module(GraphQLModule) .add_module(GraphQLModule)
.add_scoped(ScopedService)
.add_cache(AuthUser)
.add_cache(Role)
.add_transient(CityGraphType)
.add_transient(CityFilter)
.add_transient(CitySort)
.add_transient(UserGraphType)
.add_transient(UserFilter)
.add_transient(UserSort)
.add_transient(AuthUserGraphType)
.add_transient(AuthUserFilter)
.add_transient(AuthUserSort)
.add_transient(HelloQuery)
# posts
.add_transient(PostDao)
.add_transient(PostGraphType)
.add_transient(PostFilter)
.add_transient(PostSort)
builder.services.add_scoped(ScopedService) # test data
.add_singleton(TestDataSeeder)
builder.services.add_cache(AuthUser) )
builder.services.add_cache(Role)
builder.services.add_transient(CityGraphType)
builder.services.add_transient(CityFilter)
builder.services.add_transient(CitySort)
builder.services.add_transient(UserGraphType)
builder.services.add_transient(UserFilter)
builder.services.add_transient(UserSort)
builder.services.add_transient(HelloQuery)
app = builder.build() app = builder.build()
app.with_logging() app.with_logging()
app.with_migrations("./scripts")
app.with_authentication() app.with_authentication()
app.with_authorization() app.with_authorization()
@@ -66,6 +78,7 @@ def main():
schema = app.with_graphql() schema = app.with_graphql()
schema.query.string_field("ping", resolver=lambda: "pong") schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery) schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
app.with_playground() app.with_playground()
app.with_graphiql() app.with_graphiql()

View File

View File

@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Self
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
class Post(DbModelABC[Self]):
def __init__(
self,
id: int,
title: str,
content: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._title = title
self._content = content
@property
def title(self) -> str:
return self._title
@property
def content(self) -> str:
return self._content

View File

@@ -0,0 +1,11 @@
from cpl.database.abc import DbModelDaoABC
from model.post import Post
class PostDao(DbModelDaoABC):
def __init__(self):
DbModelDaoABC.__init__(self, Post, "posts")
self.attribute(Post.title, str)
self.attribute(Post.content, str)

View File

@@ -0,0 +1,38 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from model.post import Post
class PostFilter(Filter[Post]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("title", str)
self.field("content", str)
class PostSort(Sort[Post]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("title", SortOrder)
self.field("content", SortOrder)
class PostGraphType(GraphType[Post]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"title",
resolver=lambda root: root.title,
)
self.string_field(
"content",
resolver=lambda root: root.content,
)

View File

@@ -1,11 +1,43 @@
from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City from api.src.queries.cities import CityFilter, CitySort, CityGraphType, City
from api.src.queries.user import User, UserFilter, UserSort, UserGraphType from api.src.queries.user import User, UserFilter, UserSort, UserGraphType
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query from cpl.graphql.schema.query import Query
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
users = [User(i, f"User {i}") for i in range(1, 101)] users = [User(i, f"User {i}") for i in range(1, 101)]
cities = [City(i, f"City {i}") for i in range(1, 101)] cities = [City(i, f"City {i}") for i in range(1, 101)]
class AuthUserFilter(Filter[AuthUser]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("username", str)
class AuthUserSort(Sort[AuthUser]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("username", SortOrder)
class AuthUserGraphType(GraphType[AuthUser]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"username",
resolver=lambda root: root.username,
)
class HelloQuery(Query): class HelloQuery(Query):
def __init__(self): def __init__(self):
Query.__init__(self) Query.__init__(self)
@@ -28,3 +60,10 @@ class HelloQuery(Query):
CitySort, CitySort,
resolver=lambda: cities, resolver=lambda: cities,
) )
self.dao_collection_field(
AuthUserGraphType,
AuthUserDao,
"authUsers",
AuthUserFilter,
AuthUserSort,
)

View File

@@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS `posts` (
`id` INT(30) NOT NULL AUTO_INCREMENT,
`title` VARCHAR(64) NOT NULL,
`content` VARCHAR(512) NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY(`id`)
);

View File

@@ -0,0 +1,31 @@
from faker import Faker
from cpl.database.abc import DataSeederABC
from cpl.query import Enumerable
from model.post import Post
from model.post_dao import PostDao
fake = Faker()
class TestDataSeeder(DataSeederABC):
def __init__(self, posts: PostDao):
DataSeederABC.__init__(self)
self._posts = posts
async def seed(self):
if await self._posts.count() == 0:
await self._seed_posts()
async def _seed_posts(self):
posts = Enumerable.range(0, 100).select(
lambda x: Post(
id=0,
title=fake.sentence(nb_words=6),
content=fake.paragraph(nb_sentences=6),
)
).to_list()
await self._posts.create_many(posts, skip_editor=True)

View File

@@ -5,16 +5,16 @@ from cpl.core.typing import SerialId
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
class City(DbModelABC): class City(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: int, id: int,
name: str, name: str,
zip: str, zip: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -5,7 +5,7 @@ from cpl.core.typing import SerialId
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
class User(DbModelABC): class User(DbModelABC[Self]):
def __init__( def __init__(
self, self,
@@ -13,9 +13,9 @@ class User(DbModelABC):
name: str, name: str,
city_id: int = 0, city_id: int = 0,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -1,6 +1,6 @@
import secrets import secrets
from datetime import datetime from datetime import datetime
from typing import Optional, Union from typing import Optional, Union, Self
from async_property import async_property from async_property import async_property
@@ -16,7 +16,7 @@ from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__) _logger = Logger(__name__)
class ApiKey(DbModelABC): class ApiKey(DbModelABC[Self]):
def __init__( def __init__(
self, self,
@@ -25,8 +25,8 @@ class ApiKey(DbModelABC):
key: Union[str, bytes], key: Union[str, bytes],
deleted: bool = False, deleted: bool = False,
editor_id: Optional[Id] = None, editor_id: Optional[Id] = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._identifier = identifier self._identifier = identifier

View File

@@ -1,6 +1,6 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from async_property import async_property from async_property import async_property
from keycloak import KeycloakGetError from keycloak import KeycloakGetError
@@ -13,15 +13,15 @@ from cpl.database.logger import DBLogger
from cpl.dependency import get_provider from cpl.dependency import get_provider
class AuthUser(DbModelABC): class AuthUser(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
keycloak_id: str, keycloak_id: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._keycloak_id = keycloak_id self._keycloak_id = keycloak_id
@@ -87,4 +87,3 @@ class AuthUser(DbModelABC):
self._keycloak_id = str(uuid.UUID(int=0)) self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self) await auth_user_dao.update(self)

View File

@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider from cpl.dependency.context import get_provider
class AuthUserDao(DbModelDaoABC[AuthUser]): class AuthUserDao(DbModelDaoABC[AuthUser]):

View File

@@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC):
api_key_id: SerialId, api_key_id: SerialId,
permission_id: SerialId, permission_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
self._api_key_id = api_key_id self._api_key_id = api_key_id

View File

@@ -1,20 +1,20 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
class Permission(DbModelABC): class Permission(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
name: str, name: str,
description: str, description: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from async_property import async_property from async_property import async_property
@@ -9,16 +9,16 @@ from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider
class Role(DbModelABC): class Role(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
name: str, name: str,
description: str, description: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from async_property import async_property from async_property import async_property
@@ -8,16 +8,16 @@ from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider
class RolePermission(DbModelABC): class RolePermission(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
role_id: SerialId, role_id: SerialId,
permission_id: SerialId, permission_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._role_id = role_id self._role_id = role_id

View File

@@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC):
user_id: SerialId, user_id: SerialId,
role_id: SerialId, role_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
self._user_id = user_id self._user_id = user_id

View File

@@ -2,10 +2,6 @@ import os
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cpl.core.log.logger import Logger
_logger = Logger(__name__)
class CredentialManager: class CredentialManager:
r"""Handles credential encryption and decryption""" r"""Handles credential encryption and decryption"""
@@ -14,6 +10,7 @@ class CredentialManager:
@classmethod @classmethod
def with_secret(cls, file: str = None): def with_secret(cls, file: str = None):
from cpl.core.log import Logger
if file is None: if file is None:
file = ".secret" file = ".secret"
@@ -25,12 +22,12 @@ class CredentialManager:
with open(file, "w") as secret_file: with open(file, "w") as secret_file:
secret_file.write(Fernet.generate_key().decode()) secret_file.write(Fernet.generate_key().decode())
secret_file.close() secret_file.close()
_logger.warning("Secret file not found, regenerating") Logger(__name__).warning("Secret file not found, regenerating")
with open(file, "r") as secret_file: with open(file, "r") as secret_file:
secret = secret_file.read().strip() secret = secret_file.read().strip()
if secret == "" or secret is None: if secret == "" or secret is None:
_logger.fatal("No secret found in .secret file.") Logger(__name__).fatal("No secret found in .secret file.")
cls._secret = str(secret) cls._secret = str(secret)

View File

@@ -46,6 +46,10 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
def table_name(self) -> str: def table_name(self) -> str:
return self._table_name return self._table_name
@property
def type(self) -> Type[T_DBM]:
return self._model_type
def has_attribute(self, attr_name: Attribute) -> bool: def has_attribute(self, attr_name: Attribute) -> bool:
""" """
Check if the attribute exists in the DAO Check if the attribute exists in the DAO
@@ -490,16 +494,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
table, join_condition = self.__foreign_tables[attr] table, join_condition = self.__foreign_tables[attr]
builder.with_left_join(table, join_condition) builder.with_left_join(table, join_condition)
if filters: if filters is not None:
await self._build_conditions(builder, filters, external_table_deps) await self._build_conditions(builder, filters, external_table_deps)
if sorts: if sorts is not None:
self._build_sorts(builder, sorts, external_table_deps) self._build_sorts(builder, sorts, external_table_deps)
if take: if take is not None:
builder.with_limit(take) builder.with_limit(take)
if skip: if skip is not None:
builder.with_offset(skip) builder.with_offset(skip)
for external_table in external_table_deps: for external_table in external_table_deps:

View File

@@ -12,9 +12,9 @@ class DbJoinModelABC[T](DbModelABC[T]):
source_id: Id, source_id: Id,
foreign_id: Id, foreign_id: Id,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)

View File

@@ -10,9 +10,9 @@ class DbModelABC(ABC, Generic[T]):
self, self,
id: Id, id: Id,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
self._id = id self._id = id
self._deleted = deleted self._deleted = deleted

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from cpl.core.configuration import Configuration from cpl.core.configuration.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC

View File

@@ -1,15 +1,15 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
class ExecutedMigration(DbModelABC): class ExecutedMigration(DbModelABC[Self]):
def __init__( def __init__(
self, self,
migration_id: str, migration_id: str,
created: Optional[datetime] = None, created: datetime | None= None,
modified: Optional[datetime] = None, modified: datetime | None= None,
): ):
DbModelABC.__init__(self, migration_id, False, created, modified) DbModelABC.__init__(self, migration_id, False, created, modified)

View File

@@ -18,7 +18,7 @@ class CollectionGraphTypeFactory:
gql_type = strawberry.type( gql_type = strawberry.type(
type( type(
f"{node_type.__name__}Collection", f"{node_type.__name__.replace("GraphType", "")}Collection",
(), (),
{ {
"__annotations__": { "__annotations__": {

View File

@@ -4,6 +4,7 @@ from typing import Callable, Type, Any, Optional
import strawberry import strawberry
from strawberry.exceptions import StrawberryException from strawberry.exceptions import StrawberryException
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
@@ -69,9 +70,6 @@ class Query(StrawberryProtocol):
sort_type: Type[StrawberryProtocol], sort_type: Type[StrawberryProtocol],
resolver: Callable, resolver: Callable,
) -> Field: ) -> Field:
# self._schema.with_type(filter_type)
# self._schema.with_type(sort_type)
def _resolve_collection(filter=None, sort=None, skip=0, take=10): def _resolve_collection(filter=None, sort=None, skip=0, take=10):
items = resolver() items = resolver()
if filter: if filter:
@@ -103,6 +101,53 @@ class Query(StrawberryProtocol):
f.with_argument(int, "take", default_value=10) f.with_argument(int, "take", default_value=10)
return f return f
def dao_collection_field(
self,
t: Type[StrawberryProtocol],
dao_type: Type[DataAccessObjectABC],
name: str,
filter_type: Type[StrawberryProtocol],
sort_type: Type[StrawberryProtocol],
) -> Field:
assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC"
dao = self._provider.get_service(dao_type)
if not dao:
raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider")
filter = self._provider.get_service(filter_type)
if not filter:
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
sort = self._provider.get_service(sort_type)
if not sort:
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
async def _resolver(filter=None, sort=None, take=10, skip=0):
sort_dict = None
if sort is not None:
sort_dict = {}
for k, v in sort.__dict__.items():
if v is None:
continue
if isinstance(v, SortOrder):
sort_dict[k] = str(v.value).lower()
continue
sort_dict[k] = str(v).lower()
total_count = await dao.count(filter)
data = await dao.find_by(filter, sort_dict, take, skip)
return Collection(nodes=data, total_count=total_count, count=len(data))
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
f.with_argument(filter.to_strawberry(), "filter")
f.with_argument(sort.to_strawberry(), "sort")
f.with_argument(int, "skip", default_value=0)
f.with_argument(int, "take", default_value=10)
return f
@staticmethod @staticmethod
def _build_resolver(f: "Field"): def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = [] params: list[inspect.Parameter] = []
@@ -164,4 +209,4 @@ class Query(StrawberryProtocol):
namespace[name] = self._field_to_strawberry(f) namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations namespace["__annotations__"] = annotations
return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace)) return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))

View File

@@ -2,5 +2,5 @@ from enum import Enum, auto
class SortOrder(Enum): class SortOrder(Enum):
ASC = auto() ASC = "ASC"
DESC = auto() DESC = "DESC"