From 652304a4806acfbbf58868089b4afff24c5ec371 Mon Sep 17 00:00:00 2001 From: edraft Date: Sun, 28 Sep 2025 01:09:46 +0200 Subject: [PATCH] Recursive complex filtering #181 --- example/api/src/main.py | 13 ++- example/api/src/model/author.py | 30 ++++++ example/api/src/model/author_dao.py | 11 +++ example/api/src/model/author_query.py | 37 ++++++++ example/api/src/model/post.py | 6 ++ example/api/src/model/post_dao.py | 6 +- example/api/src/model/post_query.py | 14 ++- example/api/src/scripts/0-posts.sql | 16 +++- example/api/src/test_data_seeder.py | 19 +++- .../cpl/database/mysql/mysql_pool.py | 92 ++++++++++++------- .../cpl/database/postgres/postgres_pool.py | 2 +- .../cpl/graphql/schema/collection.py | 8 +- .../cpl/graphql/schema/filter/filter.py | 5 + src/cpl-graphql/cpl/graphql/schema/input.py | 9 +- src/cpl-graphql/cpl/graphql/schema/query.py | 12 ++- .../cpl/graphql/utils/type_collector.py | 17 ++++ 16 files changed, 249 insertions(+), 48 deletions(-) create mode 100644 example/api/src/model/author.py create mode 100644 example/api/src/model/author_dao.py create mode 100644 example/api/src/model/author_query.py create mode 100644 src/cpl-graphql/cpl/graphql/utils/type_collector.py diff --git a/example/api/src/main.py b/example/api/src/main.py index 777f62fe..7273381a 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -14,6 +14,8 @@ from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule from cpl.graphql.application.graphql_app import GraphQLApp from cpl.graphql.graphql_module import GraphQLModule +from model.author_dao import AuthorDao +from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort from model.post_dao import PostDao from model.post_query import PostFilter, PostSort, PostGraphType from queries.hello import HelloQuery @@ -49,14 +51,18 @@ def main(): .add_transient(AuthUserFilter) .add_transient(AuthUserSort) .add_transient(HelloQuery) + # test data + .add_singleton(TestDataSeeder) + # authors + .add_transient(AuthorDao) + .add_transient(AuthorGraphType) + .add_transient(AuthorFilter) + .add_transient(AuthorSort) # posts .add_transient(PostDao) .add_transient(PostGraphType) .add_transient(PostFilter) .add_transient(PostSort) - - # test data - .add_singleton(TestDataSeeder) ) app = builder.build() @@ -78,6 +84,7 @@ def main(): schema = app.with_graphql() schema.query.string_field("ping", resolver=lambda: "pong") schema.query.with_query("hello", HelloQuery) + schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) app.with_playground() diff --git a/example/api/src/model/author.py b/example/api/src/model/author.py new file mode 100644 index 00000000..05e2d3e3 --- /dev/null +++ b/example/api/src/model/author.py @@ -0,0 +1,30 @@ +from datetime import datetime +from typing import Self + +from cpl.core.typing import SerialId +from cpl.database.abc import DbModelABC + + +class Author(DbModelABC[Self]): + + def __init__( + self, + id: int, + first_name: str, + last_name: str, + deleted: bool = False, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, + ): + DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._first_name = first_name + self._last_name = last_name + + @property + def first_name(self) -> str: + return self._first_name + + @property + def last_name(self) -> str: + return self._last_name diff --git a/example/api/src/model/author_dao.py b/example/api/src/model/author_dao.py new file mode 100644 index 00000000..98b997a6 --- /dev/null +++ b/example/api/src/model/author_dao.py @@ -0,0 +1,11 @@ +from cpl.database.abc import DbModelDaoABC +from model.author import Author + + +class AuthorDao(DbModelDaoABC): + + def __init__(self): + DbModelDaoABC.__init__(self, Author, "authors") + + self.attribute(Author.first_name, str) + self.attribute(Author.last_name, str) \ No newline at end of file diff --git a/example/api/src/model/author_query.py b/example/api/src/model/author_query.py new file mode 100644 index 00000000..f7f1d1df --- /dev/null +++ b/example/api/src/model/author_query.py @@ -0,0 +1,37 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder +from model.author import Author + +class AuthorFilter(Filter[Author]): + def __init__(self): + Filter.__init__(self) + self.int_field("id") + self.string_field("firstName") + self.string_field("lastName") + +class AuthorSort(Sort[Author]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("firstName", SortOrder) + self.field("lastName", SortOrder) + +class AuthorGraphType(GraphType[Author]): + + def __init__(self): + GraphType.__init__(self) + + self.int_field( + "id", + resolver=lambda root: root.id, + ) + self.string_field( + "firstName", + resolver=lambda root: root.first_name, + ) + self.string_field( + "lastName", + resolver=lambda root: root.last_name, + ) diff --git a/example/api/src/model/post.py b/example/api/src/model/post.py index a2d22d60..d5801cd0 100644 --- a/example/api/src/model/post.py +++ b/example/api/src/model/post.py @@ -10,6 +10,7 @@ class Post(DbModelABC[Self]): def __init__( self, id: int, + author_id: SerialId, title: str, content: str, deleted: bool = False, @@ -18,9 +19,14 @@ class Post(DbModelABC[Self]): updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._author_id = author_id self._title = title self._content = content + @property + def author_id(self) -> SerialId: + return self._author_id + @property def title(self) -> str: return self._title diff --git a/example/api/src/model/post_dao.py b/example/api/src/model/post_dao.py index da283fef..be8e5668 100644 --- a/example/api/src/model/post_dao.py +++ b/example/api/src/model/post_dao.py @@ -1,11 +1,15 @@ from cpl.database.abc import DbModelDaoABC +from model.author_dao import AuthorDao from model.post import Post class PostDao(DbModelDaoABC): - def __init__(self): + def __init__(self, authors: AuthorDao): DbModelDaoABC.__init__(self, Post, "posts") + self.attribute(Post.author_id, int, db_name="authorId") + self.reference("author", "id", Post.author_id, "authors", authors) + self.attribute(Post.title, str) self.attribute(Post.content, str) \ No newline at end of file diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py index 2e5b2998..e3bc41af 100644 --- a/example/api/src/model/post_query.py +++ b/example/api/src/model/post_query.py @@ -2,12 +2,15 @@ from cpl.graphql.schema.filter.filter import Filter from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort_order import SortOrder +from model.author_dao import AuthorDao +from model.author_query import AuthorGraphType, AuthorFilter from model.post import Post class PostFilter(Filter[Post]): def __init__(self): Filter.__init__(self) self.int_field("id") + self.filter_field("author", AuthorFilter) self.string_field("title") self.string_field("content") @@ -20,13 +23,22 @@ class PostSort(Sort[Post]): class PostGraphType(GraphType[Post]): - def __init__(self): + def __init__(self, authors: AuthorDao): GraphType.__init__(self) self.int_field( "id", resolver=lambda root: root.id, ) + + async def _a(root: Post): + return await authors.get_by_id(root.author_id) + + self.object_field( + "author", + AuthorGraphType, + resolver=_a#lambda root: root.author_id, + ) self.string_field( "title", resolver=lambda root: root.title, diff --git a/example/api/src/scripts/0-posts.sql b/example/api/src/scripts/0-posts.sql index bf2ecc62..26268f17 100644 --- a/example/api/src/scripts/0-posts.sql +++ b/example/api/src/scripts/0-posts.sql @@ -1,7 +1,19 @@ +CREATE TABLE IF NOT EXISTS `authors` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `firstname` VARCHAR(64) NOT NULL, + `lastname` VARCHAR(64) NOT NULL, + deleted BOOLEAN NOT NULL DEFAULT FALSE, + editorId INT NULL, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(`id`) + ); + CREATE TABLE IF NOT EXISTS `posts` ( `id` INT(30) NOT NULL AUTO_INCREMENT, - `title` VARCHAR(64) NOT NULL, - `content` VARCHAR(512) NOT NULL, + `authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE, + `title` TEXT NOT NULL, + `content` TEXT NOT NULL, deleted BOOLEAN NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/example/api/src/test_data_seeder.py b/example/api/src/test_data_seeder.py index f50eea6f..38bcc1f1 100644 --- a/example/api/src/test_data_seeder.py +++ b/example/api/src/test_data_seeder.py @@ -2,6 +2,8 @@ from faker import Faker from cpl.database.abc import DataSeederABC from cpl.query import Enumerable +from model.author import Author +from model.author_dao import AuthorDao from model.post import Post from model.post_dao import PostDao @@ -11,19 +13,34 @@ fake = Faker() class TestDataSeeder(DataSeederABC): - def __init__(self, posts: PostDao): + def __init__(self, authors: AuthorDao, posts: PostDao): DataSeederABC.__init__(self) + self._authors = authors self._posts = posts async def seed(self): + if await self._authors.count() == 0: + await self._seed_authors() + if await self._posts.count() == 0: await self._seed_posts() + async def _seed_authors(self): + authors = Enumerable.range(0, 35).select( + lambda x: Author( + 0, + fake.first_name(), + fake.last_name(), + ) + ).to_list() + await self._authors.create_many(authors, skip_editor=True) + async def _seed_posts(self): posts = Enumerable.range(0, 100).select( lambda x: Post( id=0, + author_id=fake.random_int(min=1, max=35), title=fake.sentence(nb_words=6), content=fake.paragraph(nb_sentences=6), ) diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index a5422761..474bf6ce 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -1,6 +1,8 @@ from typing import Optional, Any - import sqlparse +import asyncio + +from mysql.connector import errors, PoolError from mysql.connector.aio import MySQLConnectionPool from cpl.core.environment import Environment @@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider class MySQLPool: - def __init__(self, database_settings: DatabaseSettings): self._dbconfig = { "host": database_settings.host, @@ -25,59 +26,87 @@ class MySQLPool: "ssl_disabled": database_settings.ssl_disabled, } self._pool: Optional[MySQLConnectionPool] = None + self._pool_lock = asyncio.Lock() - async def _get_pool(self): + async def _get_pool(self) -> MySQLConnectionPool: if self._pool is None: - try: - self._pool = MySQLConnectionPool( - pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig - ) - await self._pool.initialize_pool() + async with self._pool_lock: + if self._pool is None: + try: + self._pool = MySQLConnectionPool( + pool_name="cplpool", + pool_size=Environment.get("DB_POOL_SIZE", int, 20), + **self._dbconfig, + ) + await self._pool.initialize_pool() - con = await self._pool.get_connection() - async with await con.cursor() as cursor: - await cursor.execute("SELECT 1") - await cursor.fetchall() - - await con.close() - except Exception as e: - logger = get_provider().get_service(DBLogger) - logger.fatal(f"Error connecting to the database", e) + # Testverbindung (Ping) + con = await self._pool.get_connection() + try: + async with await con.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchall() + finally: + await con.close() + except Exception as e: + logger = get_provider().get_service(DBLogger) + logger.fatal("Error connecting to the database", e) + raise return self._pool + async def _get_connection(self, retries: int = 3, delay: float = 0.5): + """Stabiler Connection-Getter mit Retry und Ping""" + pool = await self._get_pool() + + for attempt in range(retries): + try: + con = await pool.get_connection() + + # Verbindungs-Check (Ping) + try: + async with await con.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchall() + except errors.OperationalError: + await con.close() + raise + + return con + + except PoolError: + if attempt == retries - 1: + raise + await asyncio.sleep(delay) + @staticmethod async def _exec_sql(cursor: Any, query: str, args=None, multi=True): result = [] if multi: queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] for q in queries: - if q.strip() == "": - continue - await cursor.execute(q, args) - if cursor.description is not None: - result = await cursor.fetchall() + if q: + await cursor.execute(q, args) + if cursor.description is not None: + result = await cursor.fetchall() else: await cursor.execute(query, args) if cursor.description is not None: result = await cursor.fetchall() - return result - async def execute(self, query: str, args=None, multi=True) -> list[list]: - pool = await self._get_pool() - con = await pool.get_connection() + async def execute(self, query: str, args=None, multi=True) -> list[str]: + con = await self._get_connection() try: async with await con.cursor() as cursor: - result = await self._exec_sql(cursor, query, args, multi) + res = await self._exec_sql(cursor, query, args, multi) await con.commit() - return result + return list(res) finally: await con.close() async def select(self, query: str, args=None, multi=True) -> list[str]: - pool = await self._get_pool() - con = await pool.get_connection() + con = await self._get_connection() try: async with await con.cursor() as cursor: res = await self._exec_sql(cursor, query, args, multi) @@ -86,8 +115,7 @@ class MySQLPool: await con.close() async def select_map(self, query: str, args=None, multi=True) -> list[dict]: - pool = await self._get_pool() - con = await pool.get_connection() + con = await self._get_connection() try: async with await con.cursor(dictionary=True) as cursor: res = await self._exec_sql(cursor, query, args, multi) diff --git a/src/cpl-database/cpl/database/postgres/postgres_pool.py b/src/cpl-database/cpl/database/postgres/postgres_pool.py index 891fb7f1..434c2655 100644 --- a/src/cpl-database/cpl/database/postgres/postgres_pool.py +++ b/src/cpl-database/cpl/database/postgres/postgres_pool.py @@ -27,7 +27,7 @@ class PostgresPool: self._pool: Optional[AsyncConnectionPool] = None async def _get_pool(self): - if self._pool is None: + if self._pool is None or self._pool.closed: pool = AsyncConnectionPool( conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) ) diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py index 6cac07a8..1d37a626 100644 --- a/src/cpl-graphql/cpl/graphql/schema/collection.py +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -3,6 +3,7 @@ from typing import Type, Dict, List import strawberry from cpl.core.typing import T +from cpl.dependency import get_provider from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol @@ -14,7 +15,12 @@ class CollectionGraphTypeFactory: if node_type in cls._cache: return cls._cache[node_type] - gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type + node_t = get_provider().get_service(node_type) + if not node_t: + raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider") + + + gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type gql_type = strawberry.type( type( diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py index d1d502e2..6463ace9 100644 --- a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -1,3 +1,5 @@ +from typing import Type + from cpl.core.typing import T from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.date_filter import DateFilter @@ -10,6 +12,9 @@ class Filter(Input[T]): def __init__(self): Input.__init__(self) + def filter_field(self, name: str, filter_type: Type["Filter"]): + self.field(name, filter_type()) + def string_field(self, name: str): self.field(name, StringFilter()) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py index 82ff31de..a4dfebdf 100644 --- a/src/cpl-graphql/cpl/graphql/schema/input.py +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -5,6 +5,7 @@ import strawberry from cpl.core.typing import T from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.schema.field import Field +from cpl.graphql.utils.type_collector import TypeCollector _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} @@ -18,12 +19,10 @@ class Input(StrawberryProtocol, Generic[T]): def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): self._fields[name] = Field(name, typ, optional=optional) - _registry: dict[type, Type] = {} - def to_strawberry(self) -> Type: cls = self.__class__ - if cls in self._registry: - return self._registry[cls] + if TypeCollector.has(cls): + return TypeCollector.get(cls) annotations = {} namespace = {} @@ -50,5 +49,5 @@ class Input(StrawberryProtocol, Generic[T]): namespace["__annotations__"] = annotations gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) - Input._registry[cls] = gql_type + TypeCollector.set(cls, gql_type) return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py index 736de81f..84270056 100644 --- a/src/cpl-graphql/cpl/graphql/schema/query.py +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -12,6 +12,7 @@ from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory from cpl.graphql.schema.field import Field from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.typing import Resolver +from cpl.graphql.utils.type_collector import TypeCollector class Query(StrawberryProtocol): @@ -54,6 +55,9 @@ class Query(StrawberryProtocol): def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: return self.field(name, list[t], resolver) + def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: + return self.field(name, t().to_strawberry(), resolver) + def with_query(self, name: str, subquery_cls: Type["Query"]): sub = self._provider.get_service(subquery_cls) if not sub: @@ -221,6 +225,10 @@ class Query(StrawberryProtocol): ) from e def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + annotations: dict[str, Any] = {} namespace: dict[str, Any] = {} @@ -229,4 +237,6 @@ class Query(StrawberryProtocol): namespace[name] = self._field_to_strawberry(f) namespace["__annotations__"] = annotations - return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/utils/type_collector.py b/src/cpl-graphql/cpl/graphql/utils/type_collector.py new file mode 100644 index 00000000..c51718bf --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/utils/type_collector.py @@ -0,0 +1,17 @@ +from typing import Type + + +class TypeCollector: + _registry: dict[type, Type] = {} + + @classmethod + def has(cls, base: type) -> bool: + return base in cls._registry + + @classmethod + def get(cls, base: type) -> Type: + return cls._registry[base] + + @classmethod + def set(cls, base: type, gql_type: Type): + cls._registry[base] = gql_type \ No newline at end of file