Recursive complex filtering #181
This commit is contained in:
@@ -14,6 +14,8 @@ from cpl.core.utils.cache import Cache
|
|||||||
from cpl.database.mysql.mysql_module import MySQLModule
|
from cpl.database.mysql.mysql_module import MySQLModule
|
||||||
from cpl.graphql.application.graphql_app import GraphQLApp
|
from cpl.graphql.application.graphql_app import GraphQLApp
|
||||||
from cpl.graphql.graphql_module import GraphQLModule
|
from cpl.graphql.graphql_module import GraphQLModule
|
||||||
|
from model.author_dao import AuthorDao
|
||||||
|
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
|
||||||
from model.post_dao import PostDao
|
from model.post_dao import PostDao
|
||||||
from model.post_query import PostFilter, PostSort, PostGraphType
|
from model.post_query import PostFilter, PostSort, PostGraphType
|
||||||
from queries.hello import HelloQuery
|
from queries.hello import HelloQuery
|
||||||
@@ -49,14 +51,18 @@ def main():
|
|||||||
.add_transient(AuthUserFilter)
|
.add_transient(AuthUserFilter)
|
||||||
.add_transient(AuthUserSort)
|
.add_transient(AuthUserSort)
|
||||||
.add_transient(HelloQuery)
|
.add_transient(HelloQuery)
|
||||||
|
# test data
|
||||||
|
.add_singleton(TestDataSeeder)
|
||||||
|
# authors
|
||||||
|
.add_transient(AuthorDao)
|
||||||
|
.add_transient(AuthorGraphType)
|
||||||
|
.add_transient(AuthorFilter)
|
||||||
|
.add_transient(AuthorSort)
|
||||||
# posts
|
# posts
|
||||||
.add_transient(PostDao)
|
.add_transient(PostDao)
|
||||||
.add_transient(PostGraphType)
|
.add_transient(PostGraphType)
|
||||||
.add_transient(PostFilter)
|
.add_transient(PostFilter)
|
||||||
.add_transient(PostSort)
|
.add_transient(PostSort)
|
||||||
|
|
||||||
# test data
|
|
||||||
.add_singleton(TestDataSeeder)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app = builder.build()
|
app = builder.build()
|
||||||
@@ -78,6 +84,7 @@ def main():
|
|||||||
schema = app.with_graphql()
|
schema = app.with_graphql()
|
||||||
schema.query.string_field("ping", resolver=lambda: "pong")
|
schema.query.string_field("ping", resolver=lambda: "pong")
|
||||||
schema.query.with_query("hello", HelloQuery)
|
schema.query.with_query("hello", HelloQuery)
|
||||||
|
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
||||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
||||||
|
|
||||||
app.with_playground()
|
app.with_playground()
|
||||||
|
|||||||
30
example/api/src/model/author.py
Normal file
30
example/api/src/model/author.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from cpl.core.typing import SerialId
|
||||||
|
from cpl.database.abc import DbModelABC
|
||||||
|
|
||||||
|
|
||||||
|
class Author(DbModelABC[Self]):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: int,
|
||||||
|
first_name: str,
|
||||||
|
last_name: str,
|
||||||
|
deleted: bool = False,
|
||||||
|
editor_id: SerialId | None = None,
|
||||||
|
created: datetime | None = None,
|
||||||
|
updated: datetime | None = None,
|
||||||
|
):
|
||||||
|
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||||
|
self._first_name = first_name
|
||||||
|
self._last_name = last_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first_name(self) -> str:
|
||||||
|
return self._first_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_name(self) -> str:
|
||||||
|
return self._last_name
|
||||||
11
example/api/src/model/author_dao.py
Normal file
11
example/api/src/model/author_dao.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from cpl.database.abc import DbModelDaoABC
|
||||||
|
from model.author import Author
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorDao(DbModelDaoABC):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
DbModelDaoABC.__init__(self, Author, "authors")
|
||||||
|
|
||||||
|
self.attribute(Author.first_name, str)
|
||||||
|
self.attribute(Author.last_name, str)
|
||||||
37
example/api/src/model/author_query.py
Normal file
37
example/api/src/model/author_query.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from cpl.graphql.schema.filter.filter import Filter
|
||||||
|
from cpl.graphql.schema.graph_type import GraphType
|
||||||
|
from cpl.graphql.schema.sort.sort import Sort
|
||||||
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
|
from model.author import Author
|
||||||
|
|
||||||
|
class AuthorFilter(Filter[Author]):
|
||||||
|
def __init__(self):
|
||||||
|
Filter.__init__(self)
|
||||||
|
self.int_field("id")
|
||||||
|
self.string_field("firstName")
|
||||||
|
self.string_field("lastName")
|
||||||
|
|
||||||
|
class AuthorSort(Sort[Author]):
|
||||||
|
def __init__(self):
|
||||||
|
Sort.__init__(self)
|
||||||
|
self.field("id", SortOrder)
|
||||||
|
self.field("firstName", SortOrder)
|
||||||
|
self.field("lastName", SortOrder)
|
||||||
|
|
||||||
|
class AuthorGraphType(GraphType[Author]):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
GraphType.__init__(self)
|
||||||
|
|
||||||
|
self.int_field(
|
||||||
|
"id",
|
||||||
|
resolver=lambda root: root.id,
|
||||||
|
)
|
||||||
|
self.string_field(
|
||||||
|
"firstName",
|
||||||
|
resolver=lambda root: root.first_name,
|
||||||
|
)
|
||||||
|
self.string_field(
|
||||||
|
"lastName",
|
||||||
|
resolver=lambda root: root.last_name,
|
||||||
|
)
|
||||||
@@ -10,6 +10,7 @@ class Post(DbModelABC[Self]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: int,
|
id: int,
|
||||||
|
author_id: SerialId,
|
||||||
title: str,
|
title: str,
|
||||||
content: str,
|
content: str,
|
||||||
deleted: bool = False,
|
deleted: bool = False,
|
||||||
@@ -18,9 +19,14 @@ class Post(DbModelABC[Self]):
|
|||||||
updated: datetime | None = None,
|
updated: datetime | None = None,
|
||||||
):
|
):
|
||||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||||
|
self._author_id = author_id
|
||||||
self._title = title
|
self._title = title
|
||||||
self._content = content
|
self._content = content
|
||||||
|
|
||||||
|
@property
|
||||||
|
def author_id(self) -> SerialId:
|
||||||
|
return self._author_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def title(self) -> str:
|
def title(self) -> str:
|
||||||
return self._title
|
return self._title
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
from cpl.database.abc import DbModelDaoABC
|
from cpl.database.abc import DbModelDaoABC
|
||||||
|
from model.author_dao import AuthorDao
|
||||||
from model.post import Post
|
from model.post import Post
|
||||||
|
|
||||||
|
|
||||||
class PostDao(DbModelDaoABC):
|
class PostDao(DbModelDaoABC):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, authors: AuthorDao):
|
||||||
DbModelDaoABC.__init__(self, Post, "posts")
|
DbModelDaoABC.__init__(self, Post, "posts")
|
||||||
|
|
||||||
|
self.attribute(Post.author_id, int, db_name="authorId")
|
||||||
|
self.reference("author", "id", Post.author_id, "authors", authors)
|
||||||
|
|
||||||
self.attribute(Post.title, str)
|
self.attribute(Post.title, str)
|
||||||
self.attribute(Post.content, str)
|
self.attribute(Post.content, str)
|
||||||
@@ -2,12 +2,15 @@ from cpl.graphql.schema.filter.filter import Filter
|
|||||||
from cpl.graphql.schema.graph_type import GraphType
|
from cpl.graphql.schema.graph_type import GraphType
|
||||||
from cpl.graphql.schema.sort.sort import Sort
|
from cpl.graphql.schema.sort.sort import Sort
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
|
from model.author_dao import AuthorDao
|
||||||
|
from model.author_query import AuthorGraphType, AuthorFilter
|
||||||
from model.post import Post
|
from model.post import Post
|
||||||
|
|
||||||
class PostFilter(Filter[Post]):
|
class PostFilter(Filter[Post]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Filter.__init__(self)
|
Filter.__init__(self)
|
||||||
self.int_field("id")
|
self.int_field("id")
|
||||||
|
self.filter_field("author", AuthorFilter)
|
||||||
self.string_field("title")
|
self.string_field("title")
|
||||||
self.string_field("content")
|
self.string_field("content")
|
||||||
|
|
||||||
@@ -20,13 +23,22 @@ class PostSort(Sort[Post]):
|
|||||||
|
|
||||||
class PostGraphType(GraphType[Post]):
|
class PostGraphType(GraphType[Post]):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, authors: AuthorDao):
|
||||||
GraphType.__init__(self)
|
GraphType.__init__(self)
|
||||||
|
|
||||||
self.int_field(
|
self.int_field(
|
||||||
"id",
|
"id",
|
||||||
resolver=lambda root: root.id,
|
resolver=lambda root: root.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _a(root: Post):
|
||||||
|
return await authors.get_by_id(root.author_id)
|
||||||
|
|
||||||
|
self.object_field(
|
||||||
|
"author",
|
||||||
|
AuthorGraphType,
|
||||||
|
resolver=_a#lambda root: root.author_id,
|
||||||
|
)
|
||||||
self.string_field(
|
self.string_field(
|
||||||
"title",
|
"title",
|
||||||
resolver=lambda root: root.title,
|
resolver=lambda root: root.title,
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS `authors` (
|
||||||
|
`id` INT(30) NOT NULL AUTO_INCREMENT,
|
||||||
|
`firstname` VARCHAR(64) NOT NULL,
|
||||||
|
`lastname` VARCHAR(64) NOT NULL,
|
||||||
|
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
editorId INT NULL,
|
||||||
|
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
PRIMARY KEY(`id`)
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS `posts` (
|
CREATE TABLE IF NOT EXISTS `posts` (
|
||||||
`id` INT(30) NOT NULL AUTO_INCREMENT,
|
`id` INT(30) NOT NULL AUTO_INCREMENT,
|
||||||
`title` VARCHAR(64) NOT NULL,
|
`authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE,
|
||||||
`content` VARCHAR(512) NOT NULL,
|
`title` TEXT NOT NULL,
|
||||||
|
`content` TEXT NOT NULL,
|
||||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
editorId INT NULL,
|
editorId INT NULL,
|
||||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from faker import Faker
|
|||||||
|
|
||||||
from cpl.database.abc import DataSeederABC
|
from cpl.database.abc import DataSeederABC
|
||||||
from cpl.query import Enumerable
|
from cpl.query import Enumerable
|
||||||
|
from model.author import Author
|
||||||
|
from model.author_dao import AuthorDao
|
||||||
from model.post import Post
|
from model.post import Post
|
||||||
from model.post_dao import PostDao
|
from model.post_dao import PostDao
|
||||||
|
|
||||||
@@ -11,19 +13,34 @@ fake = Faker()
|
|||||||
|
|
||||||
class TestDataSeeder(DataSeederABC):
|
class TestDataSeeder(DataSeederABC):
|
||||||
|
|
||||||
def __init__(self, posts: PostDao):
|
def __init__(self, authors: AuthorDao, posts: PostDao):
|
||||||
DataSeederABC.__init__(self)
|
DataSeederABC.__init__(self)
|
||||||
|
|
||||||
|
self._authors = authors
|
||||||
self._posts = posts
|
self._posts = posts
|
||||||
|
|
||||||
async def seed(self):
|
async def seed(self):
|
||||||
|
if await self._authors.count() == 0:
|
||||||
|
await self._seed_authors()
|
||||||
|
|
||||||
if await self._posts.count() == 0:
|
if await self._posts.count() == 0:
|
||||||
await self._seed_posts()
|
await self._seed_posts()
|
||||||
|
|
||||||
|
async def _seed_authors(self):
|
||||||
|
authors = Enumerable.range(0, 35).select(
|
||||||
|
lambda x: Author(
|
||||||
|
0,
|
||||||
|
fake.first_name(),
|
||||||
|
fake.last_name(),
|
||||||
|
)
|
||||||
|
).to_list()
|
||||||
|
await self._authors.create_many(authors, skip_editor=True)
|
||||||
|
|
||||||
async def _seed_posts(self):
|
async def _seed_posts(self):
|
||||||
posts = Enumerable.range(0, 100).select(
|
posts = Enumerable.range(0, 100).select(
|
||||||
lambda x: Post(
|
lambda x: Post(
|
||||||
id=0,
|
id=0,
|
||||||
|
author_id=fake.random_int(min=1, max=35),
|
||||||
title=fake.sentence(nb_words=6),
|
title=fake.sentence(nb_words=6),
|
||||||
content=fake.paragraph(nb_sentences=6),
|
content=fake.paragraph(nb_sentences=6),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from mysql.connector import errors, PoolError
|
||||||
from mysql.connector.aio import MySQLConnectionPool
|
from mysql.connector.aio import MySQLConnectionPool
|
||||||
|
|
||||||
from cpl.core.environment import Environment
|
from cpl.core.environment import Environment
|
||||||
@@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider
|
|||||||
|
|
||||||
|
|
||||||
class MySQLPool:
|
class MySQLPool:
|
||||||
|
|
||||||
def __init__(self, database_settings: DatabaseSettings):
|
def __init__(self, database_settings: DatabaseSettings):
|
||||||
self._dbconfig = {
|
self._dbconfig = {
|
||||||
"host": database_settings.host,
|
"host": database_settings.host,
|
||||||
@@ -25,35 +26,66 @@ class MySQLPool:
|
|||||||
"ssl_disabled": database_settings.ssl_disabled,
|
"ssl_disabled": database_settings.ssl_disabled,
|
||||||
}
|
}
|
||||||
self._pool: Optional[MySQLConnectionPool] = None
|
self._pool: Optional[MySQLConnectionPool] = None
|
||||||
|
self._pool_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def _get_pool(self):
|
async def _get_pool(self) -> MySQLConnectionPool:
|
||||||
|
if self._pool is None:
|
||||||
|
async with self._pool_lock:
|
||||||
if self._pool is None:
|
if self._pool is None:
|
||||||
try:
|
try:
|
||||||
self._pool = MySQLConnectionPool(
|
self._pool = MySQLConnectionPool(
|
||||||
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
|
pool_name="cplpool",
|
||||||
|
pool_size=Environment.get("DB_POOL_SIZE", int, 20),
|
||||||
|
**self._dbconfig,
|
||||||
)
|
)
|
||||||
await self._pool.initialize_pool()
|
await self._pool.initialize_pool()
|
||||||
|
|
||||||
|
# Testverbindung (Ping)
|
||||||
con = await self._pool.get_connection()
|
con = await self._pool.get_connection()
|
||||||
|
try:
|
||||||
async with await con.cursor() as cursor:
|
async with await con.cursor() as cursor:
|
||||||
await cursor.execute("SELECT 1")
|
await cursor.execute("SELECT 1")
|
||||||
await cursor.fetchall()
|
await cursor.fetchall()
|
||||||
|
finally:
|
||||||
await con.close()
|
await con.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger = get_provider().get_service(DBLogger)
|
logger = get_provider().get_service(DBLogger)
|
||||||
logger.fatal(f"Error connecting to the database", e)
|
logger.fatal("Error connecting to the database", e)
|
||||||
|
raise
|
||||||
return self._pool
|
return self._pool
|
||||||
|
|
||||||
|
async def _get_connection(self, retries: int = 3, delay: float = 0.5):
|
||||||
|
"""Stabiler Connection-Getter mit Retry und Ping"""
|
||||||
|
pool = await self._get_pool()
|
||||||
|
|
||||||
|
for attempt in range(retries):
|
||||||
|
try:
|
||||||
|
con = await pool.get_connection()
|
||||||
|
|
||||||
|
# Verbindungs-Check (Ping)
|
||||||
|
try:
|
||||||
|
async with await con.cursor() as cursor:
|
||||||
|
await cursor.execute("SELECT 1")
|
||||||
|
await cursor.fetchall()
|
||||||
|
except errors.OperationalError:
|
||||||
|
await con.close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return con
|
||||||
|
|
||||||
|
except PoolError:
|
||||||
|
if attempt == retries - 1:
|
||||||
|
raise
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
|
||||||
result = []
|
result = []
|
||||||
if multi:
|
if multi:
|
||||||
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
|
||||||
for q in queries:
|
for q in queries:
|
||||||
if q.strip() == "":
|
if q:
|
||||||
continue
|
|
||||||
await cursor.execute(q, args)
|
await cursor.execute(q, args)
|
||||||
if cursor.description is not None:
|
if cursor.description is not None:
|
||||||
result = await cursor.fetchall()
|
result = await cursor.fetchall()
|
||||||
@@ -61,23 +93,20 @@ class MySQLPool:
|
|||||||
await cursor.execute(query, args)
|
await cursor.execute(query, args)
|
||||||
if cursor.description is not None:
|
if cursor.description is not None:
|
||||||
result = await cursor.fetchall()
|
result = await cursor.fetchall()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def execute(self, query: str, args=None, multi=True) -> list[list]:
|
async def execute(self, query: str, args=None, multi=True) -> list[str]:
|
||||||
pool = await self._get_pool()
|
con = await self._get_connection()
|
||||||
con = await pool.get_connection()
|
|
||||||
try:
|
try:
|
||||||
async with await con.cursor() as cursor:
|
async with await con.cursor() as cursor:
|
||||||
result = await self._exec_sql(cursor, query, args, multi)
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
await con.commit()
|
await con.commit()
|
||||||
return result
|
return list(res)
|
||||||
finally:
|
finally:
|
||||||
await con.close()
|
await con.close()
|
||||||
|
|
||||||
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
async def select(self, query: str, args=None, multi=True) -> list[str]:
|
||||||
pool = await self._get_pool()
|
con = await self._get_connection()
|
||||||
con = await pool.get_connection()
|
|
||||||
try:
|
try:
|
||||||
async with await con.cursor() as cursor:
|
async with await con.cursor() as cursor:
|
||||||
res = await self._exec_sql(cursor, query, args, multi)
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
@@ -86,8 +115,7 @@ class MySQLPool:
|
|||||||
await con.close()
|
await con.close()
|
||||||
|
|
||||||
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
|
||||||
pool = await self._get_pool()
|
con = await self._get_connection()
|
||||||
con = await pool.get_connection()
|
|
||||||
try:
|
try:
|
||||||
async with await con.cursor(dictionary=True) as cursor:
|
async with await con.cursor(dictionary=True) as cursor:
|
||||||
res = await self._exec_sql(cursor, query, args, multi)
|
res = await self._exec_sql(cursor, query, args, multi)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class PostgresPool:
|
|||||||
self._pool: Optional[AsyncConnectionPool] = None
|
self._pool: Optional[AsyncConnectionPool] = None
|
||||||
|
|
||||||
async def _get_pool(self):
|
async def _get_pool(self):
|
||||||
if self._pool is None:
|
if self._pool is None or self._pool.closed:
|
||||||
pool = AsyncConnectionPool(
|
pool = AsyncConnectionPool(
|
||||||
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Type, Dict, List
|
|||||||
import strawberry
|
import strawberry
|
||||||
|
|
||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
|
from cpl.dependency import get_provider
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +15,12 @@ class CollectionGraphTypeFactory:
|
|||||||
if node_type in cls._cache:
|
if node_type in cls._cache:
|
||||||
return cls._cache[node_type]
|
return cls._cache[node_type]
|
||||||
|
|
||||||
gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
node_t = get_provider().get_service(node_type)
|
||||||
|
if not node_t:
|
||||||
|
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
|
||||||
|
|
||||||
|
|
||||||
|
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
|
||||||
|
|
||||||
gql_type = strawberry.type(
|
gql_type = strawberry.type(
|
||||||
type(
|
type(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
from cpl.graphql.schema.filter.bool_filter import BoolFilter
|
||||||
from cpl.graphql.schema.filter.date_filter import DateFilter
|
from cpl.graphql.schema.filter.date_filter import DateFilter
|
||||||
@@ -10,6 +12,9 @@ class Filter(Input[T]):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
Input.__init__(self)
|
Input.__init__(self)
|
||||||
|
|
||||||
|
def filter_field(self, name: str, filter_type: Type["Filter"]):
|
||||||
|
self.field(name, filter_type())
|
||||||
|
|
||||||
def string_field(self, name: str):
|
def string_field(self, name: str):
|
||||||
self.field(name, StringFilter())
|
self.field(name, StringFilter())
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import strawberry
|
|||||||
from cpl.core.typing import T
|
from cpl.core.typing import T
|
||||||
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
|
||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
|
from cpl.graphql.utils.type_collector import TypeCollector
|
||||||
|
|
||||||
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
|
||||||
|
|
||||||
@@ -18,12 +19,10 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
|
||||||
self._fields[name] = Field(name, typ, optional=optional)
|
self._fields[name] = Field(name, typ, optional=optional)
|
||||||
|
|
||||||
_registry: dict[type, Type] = {}
|
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
def to_strawberry(self) -> Type:
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
if cls in self._registry:
|
if TypeCollector.has(cls):
|
||||||
return self._registry[cls]
|
return TypeCollector.get(cls)
|
||||||
|
|
||||||
annotations = {}
|
annotations = {}
|
||||||
namespace = {}
|
namespace = {}
|
||||||
@@ -50,5 +49,5 @@ class Input(StrawberryProtocol, Generic[T]):
|
|||||||
namespace["__annotations__"] = annotations
|
namespace["__annotations__"] = annotations
|
||||||
|
|
||||||
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
|
||||||
Input._registry[cls] = gql_type
|
TypeCollector.set(cls, gql_type)
|
||||||
return gql_type
|
return gql_type
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
|
|||||||
from cpl.graphql.schema.field import Field
|
from cpl.graphql.schema.field import Field
|
||||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||||
from cpl.graphql.typing import Resolver
|
from cpl.graphql.typing import Resolver
|
||||||
|
from cpl.graphql.utils.type_collector import TypeCollector
|
||||||
|
|
||||||
|
|
||||||
class Query(StrawberryProtocol):
|
class Query(StrawberryProtocol):
|
||||||
@@ -54,6 +55,9 @@ class Query(StrawberryProtocol):
|
|||||||
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
|
||||||
return self.field(name, list[t], resolver)
|
return self.field(name, list[t], resolver)
|
||||||
|
|
||||||
|
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
|
||||||
|
return self.field(name, t().to_strawberry(), resolver)
|
||||||
|
|
||||||
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
def with_query(self, name: str, subquery_cls: Type["Query"]):
|
||||||
sub = self._provider.get_service(subquery_cls)
|
sub = self._provider.get_service(subquery_cls)
|
||||||
if not sub:
|
if not sub:
|
||||||
@@ -221,6 +225,10 @@ class Query(StrawberryProtocol):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
def to_strawberry(self) -> Type:
|
def to_strawberry(self) -> Type:
|
||||||
|
cls = self.__class__
|
||||||
|
if TypeCollector.has(cls):
|
||||||
|
return TypeCollector.get(cls)
|
||||||
|
|
||||||
annotations: dict[str, Any] = {}
|
annotations: dict[str, Any] = {}
|
||||||
namespace: dict[str, Any] = {}
|
namespace: dict[str, Any] = {}
|
||||||
|
|
||||||
@@ -229,4 +237,6 @@ class Query(StrawberryProtocol):
|
|||||||
namespace[name] = self._field_to_strawberry(f)
|
namespace[name] = self._field_to_strawberry(f)
|
||||||
|
|
||||||
namespace["__annotations__"] = annotations
|
namespace["__annotations__"] = annotations
|
||||||
return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
|
||||||
|
TypeCollector.set(cls, gql_type)
|
||||||
|
return gql_type
|
||||||
|
|||||||
17
src/cpl-graphql/cpl/graphql/utils/type_collector.py
Normal file
17
src/cpl-graphql/cpl/graphql/utils/type_collector.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
|
class TypeCollector:
|
||||||
|
_registry: dict[type, Type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def has(cls, base: type) -> bool:
|
||||||
|
return base in cls._registry
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, base: type) -> Type:
|
||||||
|
return cls._registry[base]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set(cls, base: type, gql_type: Type):
|
||||||
|
cls._registry[base] = gql_type
|
||||||
Reference in New Issue
Block a user