Further gql improvements & added test data #181
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Self
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
@@ -16,7 +16,7 @@ from cpl.dependency.service_provider import ServiceProvider
|
||||
_logger = Logger(__name__)
|
||||
|
||||
|
||||
class ApiKey(DbModelABC):
|
||||
class ApiKey(DbModelABC[Self]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -25,8 +25,8 @@ class ApiKey(DbModelABC):
|
||||
key: Union[str, bytes],
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[Id] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._identifier = identifier
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Self
|
||||
|
||||
from async_property import async_property
|
||||
from keycloak import KeycloakGetError
|
||||
@@ -13,15 +13,15 @@ from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class AuthUser(DbModelABC):
|
||||
class AuthUser(DbModelABC[Self]):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
keycloak_id: str,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._keycloak_id = keycloak_id
|
||||
@@ -87,4 +87,3 @@ class AuthUser(DbModelABC):
|
||||
|
||||
self._keycloak_id = str(uuid.UUID(int=0))
|
||||
await auth_user_dao.update(self)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
|
||||
@@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
api_key_id: SerialId,
|
||||
permission_id: SerialId,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
|
||||
self._api_key_id = api_key_id
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Self
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
|
||||
|
||||
class Permission(DbModelABC):
|
||||
class Permission(DbModelABC[Self]):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
name: str,
|
||||
description: str,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._name = name
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Self
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
@@ -9,16 +9,16 @@ from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class Role(DbModelABC):
|
||||
class Role(DbModelABC[Self]):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
name: str,
|
||||
description: str,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._name = name
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Self
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
@@ -8,16 +8,16 @@ from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
|
||||
|
||||
class RolePermission(DbModelABC):
|
||||
class RolePermission(DbModelABC[Self]):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
role_id: SerialId,
|
||||
permission_id: SerialId,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._role_id = role_id
|
||||
|
||||
@@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC):
|
||||
user_id: SerialId,
|
||||
role_id: SerialId,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
|
||||
self._user_id = user_id
|
||||
|
||||
@@ -2,10 +2,6 @@ import os
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from cpl.core.log.logger import Logger
|
||||
|
||||
_logger = Logger(__name__)
|
||||
|
||||
|
||||
class CredentialManager:
|
||||
r"""Handles credential encryption and decryption"""
|
||||
@@ -14,6 +10,7 @@ class CredentialManager:
|
||||
|
||||
@classmethod
|
||||
def with_secret(cls, file: str = None):
|
||||
from cpl.core.log import Logger
|
||||
if file is None:
|
||||
file = ".secret"
|
||||
|
||||
@@ -25,12 +22,12 @@ class CredentialManager:
|
||||
with open(file, "w") as secret_file:
|
||||
secret_file.write(Fernet.generate_key().decode())
|
||||
secret_file.close()
|
||||
_logger.warning("Secret file not found, regenerating")
|
||||
Logger(__name__).warning("Secret file not found, regenerating")
|
||||
|
||||
with open(file, "r") as secret_file:
|
||||
secret = secret_file.read().strip()
|
||||
if secret == "" or secret is None:
|
||||
_logger.fatal("No secret found in .secret file.")
|
||||
Logger(__name__).fatal("No secret found in .secret file.")
|
||||
|
||||
cls._secret = str(secret)
|
||||
|
||||
|
||||
@@ -46,6 +46,10 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
def table_name(self) -> str:
|
||||
return self._table_name
|
||||
|
||||
@property
|
||||
def type(self) -> Type[T_DBM]:
|
||||
return self._model_type
|
||||
|
||||
def has_attribute(self, attr_name: Attribute) -> bool:
|
||||
"""
|
||||
Check if the attribute exists in the DAO
|
||||
@@ -490,16 +494,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
|
||||
table, join_condition = self.__foreign_tables[attr]
|
||||
builder.with_left_join(table, join_condition)
|
||||
|
||||
if filters:
|
||||
if filters is not None:
|
||||
await self._build_conditions(builder, filters, external_table_deps)
|
||||
|
||||
if sorts:
|
||||
if sorts is not None:
|
||||
self._build_sorts(builder, sorts, external_table_deps)
|
||||
|
||||
if take:
|
||||
if take is not None:
|
||||
builder.with_limit(take)
|
||||
|
||||
if skip:
|
||||
if skip is not None:
|
||||
builder.with_offset(skip)
|
||||
|
||||
for external_table in external_table_deps:
|
||||
|
||||
@@ -12,9 +12,9 @@ class DbJoinModelABC[T](DbModelABC[T]):
|
||||
source_id: Id,
|
||||
foreign_id: Id,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ class DbModelABC(ABC, Generic[T]):
|
||||
self,
|
||||
id: Id,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None= None,
|
||||
updated: datetime | None= None,
|
||||
):
|
||||
self._id = id
|
||||
self._deleted = deleted
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Self
|
||||
|
||||
from cpl.database.abc import DbModelABC
|
||||
|
||||
|
||||
class ExecutedMigration(DbModelABC):
|
||||
class ExecutedMigration(DbModelABC[Self]):
|
||||
def __init__(
|
||||
self,
|
||||
migration_id: str,
|
||||
created: Optional[datetime] = None,
|
||||
modified: Optional[datetime] = None,
|
||||
created: datetime | None= None,
|
||||
modified: datetime | None= None,
|
||||
):
|
||||
DbModelABC.__init__(self, migration_id, False, created, modified)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class CollectionGraphTypeFactory:
|
||||
|
||||
gql_type = strawberry.type(
|
||||
type(
|
||||
f"{node_type.__name__}Collection",
|
||||
f"{node_type.__name__.replace("GraphType", "")}Collection",
|
||||
(),
|
||||
{
|
||||
"__annotations__": {
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Callable, Type, Any, Optional
|
||||
import strawberry
|
||||
from strawberry.exceptions import StrawberryException
|
||||
|
||||
from cpl.database.abc.data_access_object_abc import DataAccessObjectABC
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||
@@ -69,9 +70,6 @@ class Query(StrawberryProtocol):
|
||||
sort_type: Type[StrawberryProtocol],
|
||||
resolver: Callable,
|
||||
) -> Field:
|
||||
# self._schema.with_type(filter_type)
|
||||
# self._schema.with_type(sort_type)
|
||||
|
||||
def _resolve_collection(filter=None, sort=None, skip=0, take=10):
|
||||
items = resolver()
|
||||
if filter:
|
||||
@@ -103,6 +101,53 @@ class Query(StrawberryProtocol):
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
return f
|
||||
|
||||
def dao_collection_field(
|
||||
self,
|
||||
t: Type[StrawberryProtocol],
|
||||
dao_type: Type[DataAccessObjectABC],
|
||||
name: str,
|
||||
filter_type: Type[StrawberryProtocol],
|
||||
sort_type: Type[StrawberryProtocol],
|
||||
) -> Field:
|
||||
assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC"
|
||||
dao = self._provider.get_service(dao_type)
|
||||
if not dao:
|
||||
raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider")
|
||||
|
||||
filter = self._provider.get_service(filter_type)
|
||||
if not filter:
|
||||
raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider")
|
||||
|
||||
sort = self._provider.get_service(sort_type)
|
||||
if not sort:
|
||||
raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider")
|
||||
|
||||
async def _resolver(filter=None, sort=None, take=10, skip=0):
|
||||
sort_dict = None
|
||||
|
||||
if sort is not None:
|
||||
sort_dict = {}
|
||||
for k, v in sort.__dict__.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if isinstance(v, SortOrder):
|
||||
sort_dict[k] = str(v.value).lower()
|
||||
continue
|
||||
|
||||
sort_dict[k] = str(v).lower()
|
||||
|
||||
total_count = await dao.count(filter)
|
||||
data = await dao.find_by(filter, sort_dict, take, skip)
|
||||
return Collection(nodes=data, total_count=total_count, count=len(data))
|
||||
|
||||
f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver)
|
||||
f.with_argument(filter.to_strawberry(), "filter")
|
||||
f.with_argument(sort.to_strawberry(), "sort")
|
||||
f.with_argument(int, "skip", default_value=0)
|
||||
f.with_argument(int, "take", default_value=10)
|
||||
return f
|
||||
|
||||
@staticmethod
|
||||
def _build_resolver(f: "Field"):
|
||||
params: list[inspect.Parameter] = []
|
||||
@@ -164,4 +209,4 @@ class Query(StrawberryProtocol):
|
||||
namespace[name] = self._field_to_strawberry(f)
|
||||
|
||||
namespace["__annotations__"] = annotations
|
||||
return strawberry.type(type(f"{self.__class__.__name__}GraphType", (), namespace))
|
||||
return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||
|
||||
@@ -2,5 +2,5 @@ from enum import Enum, auto
|
||||
|
||||
|
||||
class SortOrder(Enum):
|
||||
ASC = auto()
|
||||
DESC = auto()
|
||||
ASC = "ASC"
|
||||
DESC = "DESC"
|
||||
Reference in New Issue
Block a user