WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
16 changed files with 249 additions and 48 deletions
Showing only changes of commit 20e5da5770 - Show all commits

View File

@@ -14,6 +14,8 @@ from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule from cpl.database.mysql.mysql_module import MySQLModule
from cpl.graphql.application.graphql_app import GraphQLApp from cpl.graphql.application.graphql_app import GraphQLApp
from cpl.graphql.graphql_module import GraphQLModule from cpl.graphql.graphql_module import GraphQLModule
from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
from model.post_dao import PostDao from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType from model.post_query import PostFilter, PostSort, PostGraphType
from queries.hello import HelloQuery from queries.hello import HelloQuery
@@ -49,14 +51,18 @@ def main():
.add_transient(AuthUserFilter) .add_transient(AuthUserFilter)
.add_transient(AuthUserSort) .add_transient(AuthUserSort)
.add_transient(HelloQuery) .add_transient(HelloQuery)
# test data
.add_singleton(TestDataSeeder)
# authors
.add_transient(AuthorDao)
.add_transient(AuthorGraphType)
.add_transient(AuthorFilter)
.add_transient(AuthorSort)
# posts # posts
.add_transient(PostDao) .add_transient(PostDao)
.add_transient(PostGraphType) .add_transient(PostGraphType)
.add_transient(PostFilter) .add_transient(PostFilter)
.add_transient(PostSort) .add_transient(PostSort)
# test data
.add_singleton(TestDataSeeder)
) )
app = builder.build() app = builder.build()
@@ -78,6 +84,7 @@ def main():
schema = app.with_graphql() schema = app.with_graphql()
schema.query.string_field("ping", resolver=lambda: "pong") schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery) schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
app.with_playground() app.with_playground()

View File

@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Self
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
class Author(DbModelABC[Self]):
def __init__(
self,
id: int,
first_name: str,
last_name: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._first_name = first_name
self._last_name = last_name
@property
def first_name(self) -> str:
return self._first_name
@property
def last_name(self) -> str:
return self._last_name

View File

@@ -0,0 +1,11 @@
from cpl.database.abc import DbModelDaoABC
from model.author import Author
class AuthorDao(DbModelDaoABC):
def __init__(self):
DbModelDaoABC.__init__(self, Author, "authors")
self.attribute(Author.first_name, str)
self.attribute(Author.last_name, str)

View File

@@ -0,0 +1,37 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from model.author import Author
class AuthorFilter(Filter[Author]):
def __init__(self):
Filter.__init__(self)
self.int_field("id")
self.string_field("firstName")
self.string_field("lastName")
class AuthorSort(Sort[Author]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("firstName", SortOrder)
self.field("lastName", SortOrder)
class AuthorGraphType(GraphType[Author]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"firstName",
resolver=lambda root: root.first_name,
)
self.string_field(
"lastName",
resolver=lambda root: root.last_name,
)

View File

@@ -10,6 +10,7 @@ class Post(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: int, id: int,
author_id: SerialId,
title: str, title: str,
content: str, content: str,
deleted: bool = False, deleted: bool = False,
@@ -18,9 +19,14 @@ class Post(DbModelABC[Self]):
updated: datetime | None = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._author_id = author_id
self._title = title self._title = title
self._content = content self._content = content
@property
def author_id(self) -> SerialId:
return self._author_id
@property @property
def title(self) -> str: def title(self) -> str:
return self._title return self._title

View File

@@ -1,11 +1,15 @@
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from model.author_dao import AuthorDao
from model.post import Post from model.post import Post
class PostDao(DbModelDaoABC): class PostDao(DbModelDaoABC):
def __init__(self): def __init__(self, authors: AuthorDao):
DbModelDaoABC.__init__(self, Post, "posts") DbModelDaoABC.__init__(self, Post, "posts")
self.attribute(Post.author_id, int, db_name="authorId")
self.reference("author", "id", Post.author_id, "authors", authors)
self.attribute(Post.title, str) self.attribute(Post.title, str)
self.attribute(Post.content, str) self.attribute(Post.content, str)

View File

@@ -2,12 +2,15 @@ from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter
from model.post import Post from model.post import Post
class PostFilter(Filter[Post]): class PostFilter(Filter[Post]):
def __init__(self): def __init__(self):
Filter.__init__(self) Filter.__init__(self)
self.int_field("id") self.int_field("id")
self.filter_field("author", AuthorFilter)
self.string_field("title") self.string_field("title")
self.string_field("content") self.string_field("content")
@@ -20,13 +23,22 @@ class PostSort(Sort[Post]):
class PostGraphType(GraphType[Post]): class PostGraphType(GraphType[Post]):
def __init__(self): def __init__(self, authors: AuthorDao):
GraphType.__init__(self) GraphType.__init__(self)
self.int_field( self.int_field(
"id", "id",
resolver=lambda root: root.id, resolver=lambda root: root.id,
) )
async def _a(root: Post):
return await authors.get_by_id(root.author_id)
self.object_field(
"author",
AuthorGraphType,
resolver=_a#lambda root: root.author_id,
)
self.string_field( self.string_field(
"title", "title",
resolver=lambda root: root.title, resolver=lambda root: root.title,

View File

@@ -1,7 +1,19 @@
CREATE TABLE IF NOT EXISTS `posts` ( CREATE TABLE IF NOT EXISTS `authors` (
`id` INT(30) NOT NULL AUTO_INCREMENT, `id` INT(30) NOT NULL AUTO_INCREMENT,
`title` VARCHAR(64) NOT NULL, `firstname` VARCHAR(64) NOT NULL,
`content` VARCHAR(512) NOT NULL, `lastname` VARCHAR(64) NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY(`id`)
);
CREATE TABLE IF NOT EXISTS `posts` (
`id` INT(30) NOT NULL AUTO_INCREMENT,
`authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE,
`title` TEXT NOT NULL,
`content` TEXT NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,

View File

@@ -2,6 +2,8 @@ from faker import Faker
from cpl.database.abc import DataSeederABC from cpl.database.abc import DataSeederABC
from cpl.query import Enumerable from cpl.query import Enumerable
from model.author import Author
from model.author_dao import AuthorDao
from model.post import Post from model.post import Post
from model.post_dao import PostDao from model.post_dao import PostDao
@@ -11,19 +13,34 @@ fake = Faker()
class TestDataSeeder(DataSeederABC): class TestDataSeeder(DataSeederABC):
def __init__(self, posts: PostDao): def __init__(self, authors: AuthorDao, posts: PostDao):
DataSeederABC.__init__(self) DataSeederABC.__init__(self)
self._authors = authors
self._posts = posts self._posts = posts
async def seed(self): async def seed(self):
if await self._authors.count() == 0:
await self._seed_authors()
if await self._posts.count() == 0: if await self._posts.count() == 0:
await self._seed_posts() await self._seed_posts()
async def _seed_authors(self):
authors = Enumerable.range(0, 35).select(
lambda x: Author(
0,
fake.first_name(),
fake.last_name(),
)
).to_list()
await self._authors.create_many(authors, skip_editor=True)
async def _seed_posts(self): async def _seed_posts(self):
posts = Enumerable.range(0, 100).select( posts = Enumerable.range(0, 100).select(
lambda x: Post( lambda x: Post(
id=0, id=0,
author_id=fake.random_int(min=1, max=35),
title=fake.sentence(nb_words=6), title=fake.sentence(nb_words=6),
content=fake.paragraph(nb_sentences=6), content=fake.paragraph(nb_sentences=6),
) )

View File

@@ -1,6 +1,8 @@
from typing import Optional, Any from typing import Optional, Any
import sqlparse import sqlparse
import asyncio
from mysql.connector import errors, PoolError
from mysql.connector.aio import MySQLConnectionPool from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment from cpl.core.environment import Environment
@@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider
class MySQLPool: class MySQLPool:
def __init__(self, database_settings: DatabaseSettings): def __init__(self, database_settings: DatabaseSettings):
self._dbconfig = { self._dbconfig = {
"host": database_settings.host, "host": database_settings.host,
@@ -25,35 +26,66 @@ class MySQLPool:
"ssl_disabled": database_settings.ssl_disabled, "ssl_disabled": database_settings.ssl_disabled,
} }
self._pool: Optional[MySQLConnectionPool] = None self._pool: Optional[MySQLConnectionPool] = None
self._pool_lock = asyncio.Lock()
async def _get_pool(self): async def _get_pool(self) -> MySQLConnectionPool:
if self._pool is None:
async with self._pool_lock:
if self._pool is None: if self._pool is None:
try: try:
self._pool = MySQLConnectionPool( self._pool = MySQLConnectionPool(
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig pool_name="cplpool",
pool_size=Environment.get("DB_POOL_SIZE", int, 20),
**self._dbconfig,
) )
await self._pool.initialize_pool() await self._pool.initialize_pool()
# Testverbindung (Ping)
con = await self._pool.get_connection() con = await self._pool.get_connection()
try:
async with await con.cursor() as cursor: async with await con.cursor() as cursor:
await cursor.execute("SELECT 1") await cursor.execute("SELECT 1")
await cursor.fetchall() await cursor.fetchall()
finally:
await con.close() await con.close()
except Exception as e: except Exception as e:
logger = get_provider().get_service(DBLogger) logger = get_provider().get_service(DBLogger)
logger.fatal(f"Error connecting to the database", e) logger.fatal("Error connecting to the database", e)
raise
return self._pool return self._pool
async def _get_connection(self, retries: int = 3, delay: float = 0.5):
"""Stabiler Connection-Getter mit Retry und Ping"""
pool = await self._get_pool()
for attempt in range(retries):
try:
con = await pool.get_connection()
# Verbindungs-Check (Ping)
try:
async with await con.cursor() as cursor:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except errors.OperationalError:
await con.close()
raise
return con
except PoolError:
if attempt == retries - 1:
raise
await asyncio.sleep(delay)
@staticmethod @staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True): async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
result = [] result = []
if multi: if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries: for q in queries:
if q.strip() == "": if q:
continue
await cursor.execute(q, args) await cursor.execute(q, args)
if cursor.description is not None: if cursor.description is not None:
result = await cursor.fetchall() result = await cursor.fetchall()
@@ -61,23 +93,20 @@ class MySQLPool:
await cursor.execute(query, args) await cursor.execute(query, args)
if cursor.description is not None: if cursor.description is not None:
result = await cursor.fetchall() result = await cursor.fetchall()
return result return result
async def execute(self, query: str, args=None, multi=True) -> list[list]: async def execute(self, query: str, args=None, multi=True) -> list[str]:
pool = await self._get_pool() con = await self._get_connection()
con = await pool.get_connection()
try: try:
async with await con.cursor() as cursor: async with await con.cursor() as cursor:
result = await self._exec_sql(cursor, query, args, multi) res = await self._exec_sql(cursor, query, args, multi)
await con.commit() await con.commit()
return result return list(res)
finally: finally:
await con.close() await con.close()
async def select(self, query: str, args=None, multi=True) -> list[str]: async def select(self, query: str, args=None, multi=True) -> list[str]:
pool = await self._get_pool() con = await self._get_connection()
con = await pool.get_connection()
try: try:
async with await con.cursor() as cursor: async with await con.cursor() as cursor:
res = await self._exec_sql(cursor, query, args, multi) res = await self._exec_sql(cursor, query, args, multi)
@@ -86,8 +115,7 @@ class MySQLPool:
await con.close() await con.close()
async def select_map(self, query: str, args=None, multi=True) -> list[dict]: async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
pool = await self._get_pool() con = await self._get_connection()
con = await pool.get_connection()
try: try:
async with await con.cursor(dictionary=True) as cursor: async with await con.cursor(dictionary=True) as cursor:
res = await self._exec_sql(cursor, query, args, multi) res = await self._exec_sql(cursor, query, args, multi)

View File

@@ -27,7 +27,7 @@ class PostgresPool:
self._pool: Optional[AsyncConnectionPool] = None self._pool: Optional[AsyncConnectionPool] = None
async def _get_pool(self): async def _get_pool(self):
if self._pool is None: if self._pool is None or self._pool.closed:
pool = AsyncConnectionPool( pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
) )

View File

@@ -3,6 +3,7 @@ from typing import Type, Dict, List
import strawberry import strawberry
from cpl.core.typing import T from cpl.core.typing import T
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
@@ -14,7 +15,12 @@ class CollectionGraphTypeFactory:
if node_type in cls._cache: if node_type in cls._cache:
return cls._cache[node_type] return cls._cache[node_type]
gql_node = node_type().to_strawberry() if hasattr(node_type, "to_strawberry") else node_type node_t = get_provider().get_service(node_type)
if not node_t:
raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider")
gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type
gql_type = strawberry.type( gql_type = strawberry.type(
type( type(

View File

@@ -1,3 +1,5 @@
from typing import Type
from cpl.core.typing import T from cpl.core.typing import T
from cpl.graphql.schema.filter.bool_filter import BoolFilter from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter from cpl.graphql.schema.filter.date_filter import DateFilter
@@ -10,6 +12,9 @@ class Filter(Input[T]):
def __init__(self): def __init__(self):
Input.__init__(self) Input.__init__(self)
def filter_field(self, name: str, filter_type: Type["Filter"]):
self.field(name, filter_type())
def string_field(self, name: str): def string_field(self, name: str):
self.field(name, StringFilter()) self.field(name, StringFilter())

View File

@@ -5,6 +5,7 @@ import strawberry
from cpl.core.typing import T from cpl.core.typing import T
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.utils.type_collector import TypeCollector
_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} _PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"}
@@ -18,12 +19,10 @@ class Input(StrawberryProtocol, Generic[T]):
def field(self, name: str, typ: Union[type, "Input"], optional: bool = True): def field(self, name: str, typ: Union[type, "Input"], optional: bool = True):
self._fields[name] = Field(name, typ, optional=optional) self._fields[name] = Field(name, typ, optional=optional)
_registry: dict[type, Type] = {}
def to_strawberry(self) -> Type: def to_strawberry(self) -> Type:
cls = self.__class__ cls = self.__class__
if cls in self._registry: if TypeCollector.has(cls):
return self._registry[cls] return TypeCollector.get(cls)
annotations = {} annotations = {}
namespace = {} namespace = {}
@@ -50,5 +49,5 @@ class Input(StrawberryProtocol, Generic[T]):
namespace["__annotations__"] = annotations namespace["__annotations__"] = annotations
gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace)) gql_type = strawberry.input(type(f"{cls.__name__}", (), namespace))
Input._registry[cls] = gql_type TypeCollector.set(cls, gql_type)
return gql_type return gql_type

View File

@@ -12,6 +12,7 @@ from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory
from cpl.graphql.schema.field import Field from cpl.graphql.schema.field import Field
from cpl.graphql.schema.sort.sort_order import SortOrder from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.typing import Resolver from cpl.graphql.typing import Resolver
from cpl.graphql.utils.type_collector import TypeCollector
class Query(StrawberryProtocol): class Query(StrawberryProtocol):
@@ -54,6 +55,9 @@ class Query(StrawberryProtocol):
def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field: def list_field(self, name: str, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver) return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
return self.field(name, t().to_strawberry(), resolver)
def with_query(self, name: str, subquery_cls: Type["Query"]): def with_query(self, name: str, subquery_cls: Type["Query"]):
sub = self._provider.get_service(subquery_cls) sub = self._provider.get_service(subquery_cls)
if not sub: if not sub:
@@ -221,6 +225,10 @@ class Query(StrawberryProtocol):
) from e ) from e
def to_strawberry(self) -> Type: def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
annotations: dict[str, Any] = {} annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {} namespace: dict[str, Any] = {}
@@ -229,4 +237,6 @@ class Query(StrawberryProtocol):
namespace[name] = self._field_to_strawberry(f) namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations namespace["__annotations__"] = annotations
return strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace)) gql_type = strawberry.type(type(f"{self.__class__.__name__.replace("GraphType", "")}", (), namespace))
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,17 @@
from typing import Type
class TypeCollector:
_registry: dict[type, Type] = {}
@classmethod
def has(cls, base: type) -> bool:
return base in cls._registry
@classmethod
def get(cls, base: type) -> Type:
return cls._registry[base]
@classmethod
def set(cls, base: type, gql_type: Type):
cls._registry[base] = gql_type