diff --git a/example/api/src/main.py b/example/api/src/main.py index 4732035b..4c71bbc9 100644 --- a/example/api/src/main.py +++ b/example/api/src/main.py @@ -1,40 +1,80 @@ from starlette.responses import JSONResponse +from cpl.dependency.event_bus import EventBusABC +from cpl.graphql.event_bus.memory import InMemoryEventBus +from queries.cities import CityGraphType, CityFilter, CitySort +from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType +from queries.user import UserFilter, UserSort from cpl.api.api_module import ApiModule -from cpl.api.application.web_app import WebApp from cpl.application.application_builder import ApplicationBuilder -from cpl.auth import AuthModule -from cpl.auth.permission.permissions import Permissions -from cpl.auth.schema import AuthUser, Role +from cpl.auth.schema import User, Role from cpl.core.configuration import Configuration from cpl.core.console import Console from cpl.core.environment import Environment from cpl.core.utils.cache import Cache from cpl.database.mysql.mysql_module import MySQLModule +from cpl.graphql.application.graphql_app import GraphQLApp +from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule +from cpl.graphql.graphql_module import GraphQLModule +from model.author_dao import AuthorDao +from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort +from model.post_dao import PostDao +from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription +from permissions import PostPermissions +from queries.hello import HelloQuery from scoped_service import ScopedService from service import PingService +from test_data_seeder import TestDataSeeder def main(): - builder = ApplicationBuilder[WebApp](WebApp) + builder = ApplicationBuilder[GraphQLApp](GraphQLApp) Configuration.add_json_file(f"appsettings.json") Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json") Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True) # builder.services.add_logging() - builder.services.add_structured_logging() - builder.services.add_transient(PingService) - builder.services.add_module(MySQLModule) - builder.services.add_module(ApiModule) - - builder.services.add_scoped(ScopedService) - - builder.services.add_cache(AuthUser) - builder.services.add_cache(Role) + ( + builder.services.add_structured_logging() + .add_transient(PingService) + .add_module(MySQLModule) + .add_module(ApiModule) + .add_module(GraphQLModule) + .add_module(GraphQLAuthModule) + .add_scoped(ScopedService) + .add_singleton(EventBusABC, InMemoryEventBus) + .add_cache(User) + .add_cache(Role) + .add_transient(CityGraphType) + .add_transient(CityFilter) + .add_transient(CitySort) + .add_transient(UserGraphType) + .add_transient(UserFilter) + .add_transient(UserSort) + # .add_transient(UserGraphType) + # .add_transient(UserFilter) + # .add_transient(UserSort) + .add_transient(HelloQuery) + # test data + .add_singleton(TestDataSeeder) + # authors + .add_transient(AuthorDao) + .add_transient(AuthorGraphType) + .add_transient(AuthorFilter) + .add_transient(AuthorSort) + # posts + .add_transient(PostDao) + .add_transient(PostGraphType) + .add_transient(PostFilter) + .add_transient(PostSort) + .add_transient(PostMutation) + .add_transient(PostSubscription) + ) app = builder.build() app.with_logging() + app.with_migrations("./scripts") app.with_authentication() app.with_authorization() @@ -43,13 +83,35 @@ def main(): path="/route1", fn=lambda r: JSONResponse("route1"), method="GET", - authentication=True, - permissions=[Permissions.administrator], + # authentication=True, + # permissions=[Permissions.administrator], ) app.with_routes_directory("routes") + schema = app.with_graphql() + schema.query.string_field("ping", resolver=lambda: "pong") + schema.query.with_query("hello", HelloQuery) + schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort) + ( + schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort) + # .with_require_any_permission(PostPermissions.read) + .with_public() + ) + + schema.mutation.with_mutation("post", PostMutation).with_public() + + schema.subscription.with_subscription(PostSubscription) + + app.with_auth_root_queries(True) + app.with_auth_root_mutations(True) + + app.with_playground() + app.with_graphiql() + + app.with_permissions(PostPermissions) + provider = builder.service_provider - user_cache = provider.get_service(Cache[AuthUser]) + user_cache = provider.get_service(Cache[User]) role_cache = provider.get_service(Cache[Role]) if role_cache == user_cache: diff --git a/example/api/src/model/__init__.py b/example/api/src/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/example/api/src/model/author.py b/example/api/src/model/author.py new file mode 100644 index 00000000..05e2d3e3 --- /dev/null +++ b/example/api/src/model/author.py @@ -0,0 +1,30 @@ +from datetime import datetime +from typing import Self + +from cpl.core.typing import SerialId +from cpl.database.abc import DbModelABC + + +class Author(DbModelABC[Self]): + + def __init__( + self, + id: int, + first_name: str, + last_name: str, + deleted: bool = False, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, + ): + DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._first_name = first_name + self._last_name = last_name + + @property + def first_name(self) -> str: + return self._first_name + + @property + def last_name(self) -> str: + return self._last_name diff --git a/example/api/src/model/author_dao.py b/example/api/src/model/author_dao.py new file mode 100644 index 00000000..d1b1afc0 --- /dev/null +++ b/example/api/src/model/author_dao.py @@ -0,0 +1,11 @@ +from cpl.database.abc import DbModelDaoABC +from model.author import Author + + +class AuthorDao(DbModelDaoABC): + + def __init__(self): + DbModelDaoABC.__init__(self, Author, "authors") + + self.attribute(Author.first_name, str, db_name="firstname") + self.attribute(Author.last_name, str, db_name="lastname") \ No newline at end of file diff --git a/example/api/src/model/author_query.py b/example/api/src/model/author_query.py new file mode 100644 index 00000000..3fa4ab65 --- /dev/null +++ b/example/api/src/model/author_query.py @@ -0,0 +1,37 @@ +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder +from model.author import Author + +class AuthorFilter(DbModelFilter[Author]): + def __init__(self): + DbModelFilter.__init__(self, public=True) + self.int_field("id") + self.string_field("firstName") + self.string_field("lastName") + +class AuthorSort(Sort[Author]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("firstName", SortOrder) + self.field("lastName", SortOrder) + +class AuthorGraphType(DbModelGraphType[Author]): + + def __init__(self): + DbModelGraphType.__init__(self, public=True) + + self.int_field( + "id", + resolver=lambda root: root.id, + ).with_public(True) + self.string_field( + "firstName", + resolver=lambda root: root.first_name, + ).with_public(True) + self.string_field( + "lastName", + resolver=lambda root: root.last_name, + ).with_public(True) diff --git a/example/api/src/model/post.py b/example/api/src/model/post.py new file mode 100644 index 00000000..15b670b8 --- /dev/null +++ b/example/api/src/model/post.py @@ -0,0 +1,44 @@ +from datetime import datetime +from typing import Self + +from cpl.core.typing import SerialId +from cpl.database.abc import DbModelABC + + +class Post(DbModelABC[Self]): + + def __init__( + self, + id: int, + author_id: SerialId, + title: str, + content: str, + deleted: bool = False, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, + ): + DbModelABC.__init__(self, id, deleted, editor_id, created, updated) + self._author_id = author_id + self._title = title + self._content = content + + @property + def author_id(self) -> SerialId: + return self._author_id + + @property + def title(self) -> str: + return self._title + + @title.setter + def title(self, value: str): + self._title = value + + @property + def content(self) -> str: + return self._content + + @content.setter + def content(self, value: str): + self._content = value diff --git a/example/api/src/model/post_dao.py b/example/api/src/model/post_dao.py new file mode 100644 index 00000000..3205f8de --- /dev/null +++ b/example/api/src/model/post_dao.py @@ -0,0 +1,15 @@ +from cpl.database.abc import DbModelDaoABC +from model.author_dao import AuthorDao +from model.post import Post + + +class PostDao(DbModelDaoABC[Post]): + + def __init__(self, authors: AuthorDao): + DbModelDaoABC.__init__(self, Post, "posts") + + self.attribute(Post.author_id, int, db_name="authorId") + self.reference("author", "id", Post.author_id, "authors", authors) + + self.attribute(Post.title, str) + self.attribute(Post.content, str) \ No newline at end of file diff --git a/example/api/src/model/post_query.py b/example/api/src/model/post_query.py new file mode 100644 index 00000000..5fe134f6 --- /dev/null +++ b/example/api/src/model/post_query.py @@ -0,0 +1,148 @@ +from cpl.dependency.event_bus import EventBusABC +from cpl.graphql.query_context import QueryContext +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.input import Input +from cpl.graphql.schema.mutation import Mutation +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder +from cpl.graphql.schema.subscription import Subscription +from model.author_dao import AuthorDao +from model.author_query import AuthorGraphType, AuthorFilter +from model.post import Post +from model.post_dao import PostDao + + +class PostFilter(DbModelFilter[Post]): + def __init__(self): + DbModelFilter.__init__(self, public=True) + self.int_field("id") + self.filter_field("author", AuthorFilter) + self.string_field("title") + self.string_field("content") + + +class PostSort(Sort[Post]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("title", SortOrder) + self.field("content", SortOrder) + + +class PostGraphType(DbModelGraphType[Post]): + + def __init__(self, authors: AuthorDao): + DbModelGraphType.__init__(self, public=True) + + self.int_field( + "id", + resolver=lambda root: root.id, + ).with_optional().with_public(True) + + async def _a(root: Post): + return await authors.get_by_id(root.author_id) + + def r_name(ctx: QueryContext): + return ctx.user.username == "admin" + + self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name])) + self.string_field( + "title", + resolver=lambda root: root.title, + ).with_public(True) + self.string_field( + "content", + resolver=lambda root: root.content, + ).with_public(True) + + +class PostCreateInput(Input[Post]): + title: str + content: str + author_id: int + + def __init__(self): + Input.__init__(self) + self.string_field("title").with_required() + self.string_field("content").with_required() + self.int_field("author_id").with_required() + + +class PostUpdateInput(Input[Post]): + title: str + content: str + author_id: int + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("title").with_required(False) + self.string_field("content").with_required(False) + + +class PostSubscription(Subscription): + def __init__(self, bus: EventBusABC): + Subscription.__init__(self) + self._bus = bus + + def selector(event: Post, info) -> bool: + return event.id == 101 + + self.subscription_field("postChange", PostGraphType, selector).with_public() + + +class PostMutation(Mutation): + + def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC): + Mutation.__init__(self) + + self._posts = posts + self._authors = authors + self._bus = bus + + self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument( + "input", + PostCreateInput, + ).with_required() + self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument( + "input", + PostUpdateInput, + ).with_required() + self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument( + "id", + int, + ).with_required() + self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument( + "id", + int, + ).with_required() + + async def create_post(self, input: PostCreateInput) -> int: + return await self._posts.create(Post(0, input.author_id, input.title, input.content)) + + async def update_post(self, input: PostUpdateInput) -> bool: + post = await self._posts.get_by_id(input.id) + if post is None: + return False + + post.title = input.title if input.title is not None else post.title + post.content = input.content if input.content is not None else post.content + + await self._posts.update(post) + await self._bus.publish("postChange", post) + return True + + async def delete_post(self, id: int) -> bool: + post = await self._posts.get_by_id(id) + if post is None: + return False + await self._posts.delete(post) + return True + + async def restore_post(self, id: int) -> bool: + post = await self._posts.get_by_id(id) + if post is None: + return False + await self._posts.restore(post) + return True diff --git a/example/api/src/permissions.py b/example/api/src/permissions.py new file mode 100644 index 00000000..d2e1d450 --- /dev/null +++ b/example/api/src/permissions.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class PostPermissions(Enum): + + read = "post.read" + write = "post.write" + delete = "post.delete" \ No newline at end of file diff --git a/example/api/src/queries/__init__.py b/example/api/src/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/example/api/src/queries/cities.py b/example/api/src/queries/cities.py new file mode 100644 index 00000000..7fd88273 --- /dev/null +++ b/example/api/src/queries/cities.py @@ -0,0 +1,39 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType + +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class City: + def __init__(self, id: int, name: str): + self.id = id + self.name = name + + +class CityFilter(Filter[City]): + def __init__(self): + Filter.__init__(self) + self.field("id", int) + self.field("name", str) + + +class CitySort(Sort[City]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("name", SortOrder) + + +class CityGraphType(GraphType[City]): + def __init__(self): + GraphType.__init__(self) + + self.int_field( + "id", + resolver=lambda root: root.id, + ) + self.string_field( + "name", + resolver=lambda root: root.name, + ) diff --git a/example/api/src/queries/hello.py b/example/api/src/queries/hello.py new file mode 100644 index 00000000..19a1f774 --- /dev/null +++ b/example/api/src/queries/hello.py @@ -0,0 +1,70 @@ +from queries.cities import CityFilter, CitySort, CityGraphType, City +from queries.user import User, UserFilter, UserSort, UserGraphType +from cpl.api.middleware.request import get_request +from cpl.auth.schema import UserDao, User +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.query import Query +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + +users = [User(i, f"User {i}") for i in range(1, 101)] +cities = [City(i, f"City {i}") for i in range(1, 101)] + +# class UserFilter(Filter[User]): +# def __init__(self): +# Filter.__init__(self) +# self.field("id", int) +# self.field("username", str) +# +# +# class UserSort(Sort[User]): +# def __init__(self): +# Sort.__init__(self) +# self.field("id", SortOrder) +# self.field("username", SortOrder) +# +# class UserGraphType(GraphType[User]): +# +# def __init__(self): +# GraphType.__init__(self) +# +# self.int_field( +# "id", +# resolver=lambda root: root.id, +# ) +# self.string_field( +# "username", +# resolver=lambda root: root.username, +# ) + + +class HelloQuery(Query): + def __init__(self): + Query.__init__(self) + self.string_field( + "message", + resolver=lambda name: f"Hello {name} {get_request().state.request_id}", + ).with_argument("name", str, "Name to greet", "world") + + self.collection_field( + UserGraphType, + "users", + UserFilter, + UserSort, + resolver=lambda: users, + ) + self.collection_field( + CityGraphType, + "cities", + CityFilter, + CitySort, + resolver=lambda: cities, + ) + # self.dao_collection_field( + # UserGraphType, + # UserDao, + # "Users", + # UserFilter, + # UserSort, + # ) diff --git a/example/api/src/queries/user.py b/example/api/src/queries/user.py new file mode 100644 index 00000000..a35a1780 --- /dev/null +++ b/example/api/src/queries/user.py @@ -0,0 +1,39 @@ +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class User: + def __init__(self, id: int, name: str): + self.id = id + self.name = name + + +class UserFilter(Filter[User]): + def __init__(self): + Filter.__init__(self) + self.field("id", int) + self.field("name", str) + + +class UserSort(Sort[User]): + def __init__(self): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("name", SortOrder) + + +class UserGraphType(GraphType[User]): + + def __init__(self): + GraphType.__init__(self) + + self.int_field( + "id", + resolver=lambda root: root.id, + ) + self.string_field( + "name", + resolver=lambda root: root.name, + ) diff --git a/example/api/src/scripts/0-posts.sql b/example/api/src/scripts/0-posts.sql new file mode 100644 index 00000000..26268f17 --- /dev/null +++ b/example/api/src/scripts/0-posts.sql @@ -0,0 +1,22 @@ +CREATE TABLE IF NOT EXISTS `authors` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `firstname` VARCHAR(64) NOT NULL, + `lastname` VARCHAR(64) NOT NULL, + deleted BOOLEAN NOT NULL DEFAULT FALSE, + editorId INT NULL, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(`id`) + ); + +CREATE TABLE IF NOT EXISTS `posts` ( + `id` INT(30) NOT NULL AUTO_INCREMENT, + `authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE, + `title` TEXT NOT NULL, + `content` TEXT NOT NULL, + deleted BOOLEAN NOT NULL DEFAULT FALSE, + editorId INT NULL, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY(`id`) +); \ No newline at end of file diff --git a/example/api/src/test_data_seeder.py b/example/api/src/test_data_seeder.py new file mode 100644 index 00000000..38bcc1f1 --- /dev/null +++ b/example/api/src/test_data_seeder.py @@ -0,0 +1,48 @@ +from faker import Faker + +from cpl.database.abc import DataSeederABC +from cpl.query import Enumerable +from model.author import Author +from model.author_dao import AuthorDao +from model.post import Post +from model.post_dao import PostDao + + +fake = Faker() + + +class TestDataSeeder(DataSeederABC): + + def __init__(self, authors: AuthorDao, posts: PostDao): + DataSeederABC.__init__(self) + + self._authors = authors + self._posts = posts + + async def seed(self): + if await self._authors.count() == 0: + await self._seed_authors() + + if await self._posts.count() == 0: + await self._seed_posts() + + async def _seed_authors(self): + authors = Enumerable.range(0, 35).select( + lambda x: Author( + 0, + fake.first_name(), + fake.last_name(), + ) + ).to_list() + await self._authors.create_many(authors, skip_editor=True) + + async def _seed_posts(self): + posts = Enumerable.range(0, 100).select( + lambda x: Post( + id=0, + author_id=fake.random_int(min=1, max=35), + title=fake.sentence(nb_words=6), + content=fake.paragraph(nb_sentences=6), + ) + ).to_list() + await self._posts.create_many(posts, skip_editor=True) diff --git a/example/database/src/model/city.py b/example/database/src/model/city.py index c98bef85..2d61f92f 100644 --- a/example/database/src/model/city.py +++ b/example/database/src/model/city.py @@ -5,16 +5,16 @@ from cpl.core.typing import SerialId from cpl.database.abc.db_model_abc import DbModelABC -class City(DbModelABC): +class City(DbModelABC[Self]): def __init__( self, id: int, name: str, zip: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/example/database/src/model/user.py b/example/database/src/model/user.py index 445c56b7..e0116423 100644 --- a/example/database/src/model/user.py +++ b/example/database/src/model/user.py @@ -5,7 +5,7 @@ from cpl.core.typing import SerialId from cpl.database.abc.db_model_abc import DbModelABC -class User(DbModelABC): +class User(DbModelABC[Self]): def __init__( self, @@ -13,9 +13,9 @@ class User(DbModelABC): name: str, city_id: int = 0, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None= None, + updated: datetime | None= None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-api/cpl/api/abc/web_app_abc.py b/src/cpl-api/cpl/api/abc/web_app_abc.py new file mode 100644 index 00000000..fa7eec6e --- /dev/null +++ b/src/cpl-api/cpl/api/abc/web_app_abc.py @@ -0,0 +1,45 @@ +from abc import ABC +from enum import Enum +from typing import Self + +from starlette.applications import Starlette + +from cpl.api.model.api_route import ApiRoute +from cpl.api.model.validation_match import ValidationMatch +from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput +from cpl.application.abc.application_abc import ApplicationABC +from cpl.dependency.service_provider import ServiceProvider +from cpl.dependency.typing import Modules + + +class WebAppABC(ApplicationABC, ABC): + + def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None): + ApplicationABC.__init__(self, services, modules, required_modules) + + def with_routes_directory(self, directory: str) -> Self: ... + def with_app(self, app: Starlette) -> Self: ... + def with_routes( + self, + routes: list[ApiRoute], + method: HTTPMethods, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: ... + def with_route( + self, + path: str, + fn: TEndpoint, + method: HTTPMethods, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: ... + def with_middleware(self, middleware: PartialMiddleware) -> Self: ... + def with_authentication(self) -> Self: ... + def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: ... diff --git a/src/cpl-api/cpl/api/application/web_app.py b/src/cpl-api/cpl/api/application/web_app.py index 476e54d2..f94694f9 100644 --- a/src/cpl-api/cpl/api/application/web_app.py +++ b/src/cpl-api/cpl/api/application/web_app.py @@ -1,6 +1,6 @@ import os from enum import Enum -from typing import Mapping, Any, Callable, Self, Union +from typing import Mapping, Any, Self import uvicorn from starlette.applications import Starlette @@ -10,6 +10,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.types import ExceptionHandler +from cpl.api.abc.web_app_abc import WebAppABC from cpl.api.api_module import ApiModule from cpl.api.error import APIError from cpl.api.logger import APILogger @@ -24,8 +25,7 @@ from cpl.api.registry.policy import PolicyRegistry from cpl.api.registry.route import RouteRegistry from cpl.api.router import Router from cpl.api.settings import ApiSettings -from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver -from cpl.application.abc.application_abc import ApplicationABC +from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput from cpl.auth.auth_module import AuthModule from cpl.auth.permission.permission_module import PermissionsModule from cpl.core.configuration.configuration import Configuration @@ -33,12 +33,12 @@ from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.typing import Modules -PolicyInput = Union[dict[str, PolicyResolver], Policy] - -class WebApp(ApplicationABC): - def __init__(self, services: ServiceProvider, modules: Modules): - super().__init__(services, modules, [AuthModule, PermissionsModule, ApiModule]) +class WebApp(WebAppABC): + def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None): + WebAppABC.__init__( + self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or []) + ) self._app: Starlette | None = None self._logger = services.get_service(APILogger) @@ -78,16 +78,17 @@ class WebApp(ApplicationABC): self._logger.debug(f"Allowed origins: {origins}") return origins.split(",") - def with_app(self, app: Starlette) -> Self: - assert app is not None, "app must not be None" - assert isinstance(app, Starlette), "app must be an instance of Starlette" - self._app = app - return self - def _check_for_app(self): if self._app is not None: raise ValueError("App is already set, cannot add routes or middleware") + def _validate_policies(self): + for rule in Router.get_authorization_rules(): + for policy_name in rule["policies"]: + policy = self._policies.get(policy_name) + if not policy: + self._logger.fatal(f"Authorization policy '{policy_name}' not found") + def with_routes_directory(self, directory: str) -> Self: self._check_for_app() assert directory is not None, "directory must not be None" @@ -102,6 +103,12 @@ class WebApp(ApplicationABC): return self + def with_app(self, app: Starlette) -> Self: + assert app is not None, "app must not be None" + assert isinstance(app, Starlette), "app must be an instance of Starlette" + self._app = app + return self + def with_routes( self, routes: list[ApiRoute], @@ -131,7 +138,7 @@ class WebApp(ApplicationABC): def with_route( self, path: str, - fn: Callable[[Request], Any], + fn: TEndpoint, method: HTTPMethods, authentication: bool = False, roles: list[str | Enum] = None, @@ -162,6 +169,30 @@ class WebApp(ApplicationABC): return self + def with_websocket( + self, + path: str, + fn: TEndpoint, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self._check_for_app() + assert path is not None, "path must not be None" + assert fn is not None, "fn must not be None" + + Router.websocket(path, registry=self._routes)(fn) + + if authentication: + Router.authenticate()(fn) + + if roles or permissions or policies: + Router.authorize(roles, permissions, policies, match)(fn) + + return self + def with_middleware(self, middleware: PartialMiddleware) -> Self: self._check_for_app() @@ -179,6 +210,7 @@ class WebApp(ApplicationABC): return self def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: + self._check_for_app() if policies: _policies = [] @@ -206,12 +238,8 @@ class WebApp(ApplicationABC): self.with_middleware(AuthorizationMiddleware) return self - def _validate_policies(self): - for rule in Router.get_authorization_rules(): - for policy_name in rule["policies"]: - policy = self._policies.get(policy_name) - if not policy: - self._logger.fatal(f"Authorization policy '{policy_name}' not found") + async def _log_before_startup(self): + self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") async def main(self): self._logger.debug(f"Preparing API") @@ -236,7 +264,7 @@ class WebApp(ApplicationABC): else: app = self._app - self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") + await self._log_before_startup() config = uvicorn.Config( app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" diff --git a/src/cpl-api/cpl/api/error.py b/src/cpl-api/cpl/api/error.py index 50329e98..8fad7e5e 100644 --- a/src/cpl-api/cpl/api/error.py +++ b/src/cpl-api/cpl/api/error.py @@ -8,7 +8,7 @@ class APIError(HTTPException): status_code = 500 def __init__(self, message: str = ""): - super().__init__(self.status_code, message) + HTTPException.__init__(self, self.status_code, message) self._message = message @property diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index c0dc95f1..8b40cdd1 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -7,13 +7,13 @@ from cpl.api.logger import APILogger from cpl.api.middleware.request import get_request from cpl.api.router import Router from cpl.auth.keycloak import KeycloakClient -from cpl.auth.schema import AuthUserDao, AuthUser +from cpl.auth.schema import UserDao, User from cpl.core.ctx import set_user class AuthenticationMiddleware(ASGIMiddleware): - def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): + def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._logger = logger @@ -25,6 +25,21 @@ class AuthenticationMiddleware(ASGIMiddleware): request = get_request() url = request.url.path + if url not in Router.get_auth_required_routes(): + self._logger.trace(f"No authentication required for {url}") + return await self._app(scope, receive, send) + + user = getattr(request.state, "user", None) + if not user or user.deleted: + self._logger.debug(f"Unauthorized access to {url}, user missing or deleted") + return await Unauthorized("Unauthorized").asgi_response(scope, receive, send) + + return await self._call_next(scope, receive, send) + + async def _old_call__(self, scope: Scope, receive: Receive, send: Send): + request = get_request() + url = request.url.path + if url not in Router.get_auth_required_routes(): self._logger.trace(f"No authentication required for {url}") return await self._app(scope, receive, send) @@ -57,12 +72,12 @@ class AuthenticationMiddleware(ASGIMiddleware): return await self._call_next(scope, receive, send) - async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser: + async def _get_or_crate_user(self, keycloak_id: str) -> User: existing = await self._user_dao.find_by_keycloak_id(keycloak_id) if existing is not None: return existing - user = AuthUser(0, keycloak_id) + user = User(0, keycloak_id) uid = await self._user_dao.create(user) return await self._user_dao.get_by_id(uid) diff --git a/src/cpl-api/cpl/api/middleware/authorization.py b/src/cpl-api/cpl/api/middleware/authorization.py index b0b0d18c..64347cdc 100644 --- a/src/cpl-api/cpl/api/middleware/authorization.py +++ b/src/cpl-api/cpl/api/middleware/authorization.py @@ -7,13 +7,13 @@ from cpl.api.middleware.request import get_request from cpl.api.model.validation_match import ValidationMatch from cpl.api.registry.policy import PolicyRegistry from cpl.api.router import Router -from cpl.auth.schema._administration.auth_user_dao import AuthUserDao +from cpl.auth.schema._administration.user_dao import UserDao from cpl.core.ctx.user_context import get_user class AuthorizationMiddleware(ASGIMiddleware): - def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): + def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._logger = logger diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 0cedc88b..d5e73721 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -5,10 +5,15 @@ from uuid import uuid4 from starlette.requests import Request from starlette.types import Scope, Receive, Send +from starlette.websockets import WebSocket from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.logger import APILogger from cpl.api.typing import TRequest +from cpl.auth.keycloak.keycloak_client import KeycloakClient +from cpl.auth.schema import User +from cpl.auth.schema._administration.user_dao import UserDao +from cpl.core.ctx import set_user from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider @@ -17,19 +22,23 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa class RequestMiddleware(ASGIMiddleware): - def __init__(self, app, provider: ServiceProvider, logger: APILogger): + def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao): ASGIMiddleware.__init__(self, app) self._provider = provider self._logger = logger + self._keycloak = keycloak + self._user_dao = user_dao + self._ctx_token = None async def __call__(self, scope: Scope, receive: Receive, send: Send): - request = Request(scope, receive, send) + request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send) await self.set_request_data(request) try: + await self._try_set_user(request) with self._provider.create_scope(): inject(await self._app(scope, receive, send)) finally: @@ -53,6 +62,37 @@ class RequestMiddleware(ASGIMiddleware): self._logger.trace(f"Clearing current request: {request.state.request_id}") _request_context.reset(self._ctx_token) + async def _try_set_user(self, request: Request): + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return + + token = auth_header.split("Bearer ")[1] + try: + token_info = self._keycloak.introspect(token) + if not token_info.get("active", False): + return + + keycloak_id = self._keycloak.get_user_id(token) + if not keycloak_id: + return + + user = await self._user_dao.find_by_keycloak_id(keycloak_id) + if not user: + user = User(0, keycloak_id) + uid = await self._user_dao.create(user) + user = await self._user_dao.get_by_id(uid) + + if user.deleted: + return + + request.state.user = user + set_user(user) + self._logger.trace(f"User {user.id} bound to request {request.state.request_id}") + + except Exception as e: + self._logger.debug(f"Silent user binding failed: {e}") + def get_request() -> Optional[TRequest]: return _request_context.get() diff --git a/src/cpl-api/cpl/api/model/websocket_route.py b/src/cpl-api/cpl/api/model/websocket_route.py new file mode 100644 index 00000000..3c09ca3f --- /dev/null +++ b/src/cpl-api/cpl/api/model/websocket_route.py @@ -0,0 +1,31 @@ +from typing import Callable + +import starlette.routing + + +class WebSocketRoute: + + def __init__(self, path: str, fn: Callable, **kwargs): + self._path = path + self._fn = fn + + self._kwargs = kwargs + + @property + def name(self) -> str: + return self._fn.__name__ + + @property + def fn(self) -> Callable: + return self._fn + + @property + def path(self) -> str: + return self._path + + @property + def kwargs(self) -> dict: + return self._kwargs + + def to_starlette(self, *args) -> starlette.routing.WebSocketRoute: + return starlette.routing.WebSocketRoute(self._path, self._fn) diff --git a/src/cpl-api/cpl/api/registry/route.py b/src/cpl-api/cpl/api/registry/route.py index e030007b..83ce7862 100644 --- a/src/cpl-api/cpl/api/registry/route.py +++ b/src/cpl-api/cpl/api/registry/route.py @@ -1,32 +1,35 @@ -from typing import Optional +from typing import Optional, Union from cpl.api.model.api_route import ApiRoute +from cpl.api.model.websocket_route import WebSocketRoute from cpl.core.abc.registry_abc import RegistryABC +TRoute = Union[ApiRoute, WebSocketRoute] + class RouteRegistry(RegistryABC): def __init__(self): RegistryABC.__init__(self) - def extend(self, items: list[ApiRoute]): + def extend(self, items: list[TRoute]): for policy in items: self.add(policy) - def add(self, item: ApiRoute): - assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" + def add(self, item: TRoute): + assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute" if item.path in self._items: raise ValueError(f"ApiRoute {item.path} is already registered") self._items[item.path] = item - def set(self, item: ApiRoute): + def set(self, item: TRoute): assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" self._items[item.path] = item - def get(self, key: str) -> Optional[ApiRoute]: + def get(self, key: str) -> Optional[TRoute]: return self._items.get(key) - def all(self) -> list[ApiRoute]: + def all(self) -> list[TRoute]: return list(self._items.values()) diff --git a/src/cpl-api/cpl/api/router.py b/src/cpl-api/cpl/api/router.py index 27dfd5ab..55369c38 100644 --- a/src/cpl-api/cpl/api/router.py +++ b/src/cpl-api/cpl/api/router.py @@ -91,6 +91,22 @@ class Router: return inner + @classmethod + def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs): + from cpl.api.model.websocket_route import WebSocketRoute + + if not registry: + routes = get_provider().get_service(RouteRegistry) + else: + routes = registry + + def inner(fn): + routes.add(WebSocketRoute(path, fn, **kwargs)) + setattr(fn, "_route_path", path) + return fn + + return inner + @classmethod def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): from cpl.api.model.api_route import ApiRoute diff --git a/src/cpl-api/cpl/api/settings.py b/src/cpl-api/cpl/api/settings.py index 2f11f5d7..900c2dd2 100644 --- a/src/cpl-api/cpl/api/settings.py +++ b/src/cpl-api/cpl/api/settings.py @@ -6,7 +6,7 @@ from cpl.core.configuration import ConfigurationModelABC class ApiSettings(ConfigurationModelABC): def __init__(self, src: Optional[dict] = None): - super().__init__(src) + ConfigurationModelABC.__init__(self, src) self.option("host", str, "0.0.0.0") self.option("port", int, 5000) diff --git a/src/cpl-api/cpl/api/typing.py b/src/cpl-api/cpl/api/typing.py index c8319900..8d5f0c73 100644 --- a/src/cpl-api/cpl/api/typing.py +++ b/src/cpl-api/cpl/api/typing.py @@ -2,13 +2,15 @@ from typing import Union, Literal, Callable, Type, Awaitable from urllib.request import Request from starlette.middleware import Middleware +from starlette.responses import Response from starlette.types import ASGIApp from starlette.websockets import WebSocket from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware -from cpl.auth.schema import AuthUser +from cpl.auth.schema import User TRequest = Union[Request, WebSocket] +TEndpoint = Callable[[TRequest, ...], Awaitable[Response]] | Callable[[TRequest, ...], Response] HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] PartialMiddleware = Union[ ASGIMiddleware, @@ -16,4 +18,5 @@ PartialMiddleware = Union[ Middleware, Callable[[ASGIApp], ASGIApp], ] -PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] +PolicyResolver = Callable[[User], bool | Awaitable[bool]] +PolicyInput = Union[dict[str, PolicyResolver], "Policy"] diff --git a/src/cpl-application/cpl/application/abc/application_abc.py b/src/cpl-application/cpl/application/abc/application_abc.py index 59c43b88..a90db406 100644 --- a/src/cpl-application/cpl/application/abc/application_abc.py +++ b/src/cpl-application/cpl/application/abc/application_abc.py @@ -56,7 +56,7 @@ class ApplicationABC(ABC): module_dependency_error( type(self).__name__, - module.__name__, + module.__name__ if not isinstance(module, str) else module, ImportError( f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method." ), diff --git a/src/cpl-auth/cpl/auth/auth_module.py b/src/cpl-auth/cpl/auth/auth_module.py index ea2b8582..aa1f7bef 100644 --- a/src/cpl-auth/cpl/auth/auth_module.py +++ b/src/cpl-auth/cpl/auth/auth_module.py @@ -12,7 +12,7 @@ from cpl.dependency.service_provider import ServiceProvider from .keycloak.keycloak_admin import KeycloakAdmin from .keycloak.keycloak_client import KeycloakClient from .schema._administration.api_key_dao import ApiKeyDao -from .schema._administration.auth_user_dao import AuthUserDao +from .schema._administration.user_dao import UserDao from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao from .schema._permission.permission_dao import PermissionDao from .schema._permission.role_dao import RoleDao @@ -26,7 +26,7 @@ class AuthModule(Module): singleton = [ KeycloakClient, KeycloakAdmin, - AuthUserDao, + UserDao, ApiKeyDao, ApiKeyPermissionDao, PermissionDao, diff --git a/src/cpl-auth/cpl/auth/permission/permission_module.py b/src/cpl-auth/cpl/auth/permission/permission_module.py index 16955c57..eafaeadc 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_module.py +++ b/src/cpl-auth/cpl/auth/permission/permission_module.py @@ -2,6 +2,7 @@ from cpl.auth.auth_module import AuthModule from cpl.auth.permission.permission_seeder import PermissionSeeder from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry +from cpl.auth.permission.role_seeder import RoleSeeder from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.database_module import DatabaseModule from cpl.dependency.module.module import Module @@ -10,7 +11,7 @@ from cpl.dependency.service_collection import ServiceCollection class PermissionsModule(Module): dependencies = [DatabaseModule, AuthModule] - singleton = [(DataSeederABC, PermissionSeeder)] + transient = [(DataSeederABC, PermissionSeeder), (DataSeederABC, RoleSeeder)] @staticmethod def register(collection: ServiceCollection): diff --git a/src/cpl-auth/cpl/auth/permission/permission_seeder.py b/src/cpl-auth/cpl/auth/permission/permission_seeder.py index d9d42cfa..aab41139 100644 --- a/src/cpl-auth/cpl/auth/permission/permission_seeder.py +++ b/src/cpl-auth/cpl/auth/permission/permission_seeder.py @@ -1,4 +1,3 @@ -from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.schema import ( Permission, diff --git a/src/cpl-auth/cpl/auth/permission/role_seeder.py b/src/cpl-auth/cpl/auth/permission/role_seeder.py new file mode 100644 index 00000000..b6a2db43 --- /dev/null +++ b/src/cpl-auth/cpl/auth/permission/role_seeder.py @@ -0,0 +1,60 @@ +from cpl.auth.schema import ( + Role, + RolePermission, + PermissionDao, + RoleDao, + RolePermissionDao, + ApiKeyDao, + ApiKeyPermissionDao, + UserDao, + RoleUserDao, + RoleUser, +) +from cpl.database.abc.data_seeder_abc import DataSeederABC +from cpl.database.logger import DBLogger + + +class RoleSeeder(DataSeederABC): + def __init__( + self, + logger: DBLogger, + permission_dao: PermissionDao, + role_dao: RoleDao, + role_permission_dao: RolePermissionDao, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + user_dao: UserDao, + role_user_dao: RoleUserDao, + ): + DataSeederABC.__init__(self) + self._logger = logger + self._permission_dao = permission_dao + self._role_dao = role_dao + self._role_permission_dao = role_permission_dao + self._api_key_dao = api_key_dao + self._api_key_permission_dao = api_key_permission_dao + self._user_dao = user_dao + self._role_user_dao = role_user_dao + + async def seed(self): + self._logger.info("Creating admin role") + roles = await self._role_dao.get_all() + if len(roles) == 0: + rid = await self._role_dao.create(Role(0, "admin", "Default admin role")) + permissions = await self._permission_dao.get_all() + + await self._role_permission_dao.create_many( + [RolePermission(0, rid, permission.id) for permission in permissions] + ) + + role = await self._role_dao.get_by_name("admin") + if len(await role.users) > 0: + return + + users = await self._user_dao.get_all() + if len(users) == 0: + return + + user = users[0] + self._logger.warning(f"Assigning admin role to first user {user.id}") + await self._role_user_dao.create(RoleUser(0, role.id, user.id)) diff --git a/src/cpl-auth/cpl/auth/schema/__init__.py b/src/cpl-auth/cpl/auth/schema/__init__.py index cdb4b9d1..af3373ee 100644 --- a/src/cpl-auth/cpl/auth/schema/__init__.py +++ b/src/cpl-auth/cpl/auth/schema/__init__.py @@ -1,7 +1,7 @@ from ._administration.api_key import ApiKey from ._administration.api_key_dao import ApiKeyDao -from ._administration.auth_user import AuthUser -from ._administration.auth_user_dao import AuthUserDao +from ._administration.user import User +from ._administration.user_dao import UserDao from ._permission.api_key_permission import ApiKeyPermission from ._permission.api_key_permission_dao import ApiKeyPermissionDao diff --git a/src/cpl-auth/cpl/auth/schema/_administration/api_key.py b/src/cpl-auth/cpl/auth/schema/_administration/api_key.py index 16f57a7d..9a6d5f6c 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/api_key.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/api_key.py @@ -1,6 +1,6 @@ import secrets from datetime import datetime -from typing import Optional, Union +from typing import Optional, Union, Self from async_property import async_property @@ -16,7 +16,7 @@ from cpl.dependency.service_provider import ServiceProvider _logger = Logger(__name__) -class ApiKey(DbModelABC): +class ApiKey(DbModelABC[Self]): def __init__( self, @@ -25,8 +25,8 @@ class ApiKey(DbModelABC): key: Union[str, bytes], deleted: bool = False, editor_id: Optional[Id] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._identifier = identifier diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py b/src/cpl-auth/cpl/auth/schema/_administration/user.py similarity index 72% rename from src/cpl-auth/cpl/auth/schema/_administration/auth_user.py rename to src/cpl-auth/cpl/auth/schema/_administration/user.py index cae14f97..f20740e6 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/user.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from typing import Optional +from typing import Optional, Self from async_property import async_property from keycloak import KeycloakGetError @@ -10,18 +10,18 @@ from cpl.auth.permission.permissions import Permissions from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC from cpl.database.logger import DBLogger -from cpl.dependency import ServiceProvider +from cpl.dependency import get_provider -class AuthUser(DbModelABC): +class User(DbModelABC[Self]): def __init__( self, id: SerialId, keycloak_id: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._keycloak_id = keycloak_id @@ -69,21 +69,21 @@ class AuthUser(DbModelABC): @async_property async def permissions(self): - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) - return await auth_user_dao.get_permissions(self.id) + user_dao: UserDao = get_provider().get_service(UserDao) + return await user_dao.get_permissions(self.id) async def has_permission(self, permission: Permissions) -> bool: - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) - return await auth_user_dao.has_permission(self.id, permission) + user_dao: UserDao = get_provider().get_service(UserDao) + return await user_dao.has_permission(self.id, permission) async def anonymize(self): - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) + user_dao: UserDao = get_provider().get_service(UserDao) self._keycloak_id = str(uuid.UUID(int=0)) - await auth_user_dao.update(self) + await user_dao.update(self) diff --git a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py b/src/cpl-auth/cpl/auth/schema/_administration/user_dao.py similarity index 73% rename from src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py rename to src/cpl-auth/cpl/auth/schema/_administration/user_dao.py index 8963259f..206ab553 100644 --- a/src/cpl-auth/cpl/auth/schema/_administration/auth_user_dao.py +++ b/src/cpl-auth/cpl/auth/schema/_administration/user_dao.py @@ -1,19 +1,23 @@ from typing import Optional, Union from cpl.auth.permission.permissions import Permissions -from cpl.auth.schema._administration.auth_user import AuthUser +from cpl.auth.schema._permission.permission_dao import PermissionDao +from cpl.auth.schema._permission.permission import Permission +from cpl.auth.schema._administration.user import User from cpl.database import TableManager from cpl.database.abc import DbModelDaoABC from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder -from cpl.dependency import ServiceProvider +from cpl.dependency.context import get_provider -class AuthUserDao(DbModelDaoABC[AuthUser]): +class UserDao(DbModelDaoABC[User]): - def __init__(self): - DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) + def __init__(self, permission_dao: PermissionDao): + DbModelDaoABC.__init__(self, User, TableManager.get("users")) - self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") + self._permissions = permission_dao + + self.attribute(User.keycloak_id, str) async def get_users(): return [(x.id, x.username, x.email) for x in await self.get_all()] @@ -27,11 +31,11 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): .with_value_getter(get_users) ) - async def get_by_keycloak_id(self, keycloak_id: str) -> AuthUser: - return await self.get_single_by({AuthUser.keycloak_id: keycloak_id}) + async def get_by_keycloak_id(self, keycloak_id: str) -> User: + return await self.get_single_by({User.keycloak_id: keycloak_id}) - async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[AuthUser]: - return await self.find_single_by({AuthUser.keycloak_id: keycloak_id}) + async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[User]: + return await self.find_single_by({User.keycloak_id: keycloak_id}) async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool: from cpl.auth.schema._permission.permission_dao import PermissionDao @@ -54,7 +58,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): return result[0]["count"] > 0 - async def get_permissions(self, user_id: int) -> list[Permissions]: + async def get_permissions(self, user_id: int) -> list[Permission]: result = await self._db.select_map( f""" SELECT p.* @@ -66,4 +70,4 @@ class AuthUserDao(DbModelDaoABC[AuthUser]): AND ru.deleted = FALSE; """ ) - return [Permissions(p["name"]) for p in result] + return [self._permissions.to_object(x) for x in result] diff --git a/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py index 8a7f8e4b..5a807e76 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/api_key_permission.py @@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC): api_key_id: SerialId, permission_id: SerialId, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated) self._api_key_id = api_key_id diff --git a/src/cpl-auth/cpl/auth/schema/_permission/permission.py b/src/cpl-auth/cpl/auth/schema/_permission/permission.py index e5bb046d..6ca5849a 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/permission.py @@ -1,20 +1,20 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -class Permission(DbModelABC): +class Permission(DbModelABC[Self]): def __init__( self, id: SerialId, name: str, description: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role.py b/src/cpl-auth/cpl/auth/schema/_permission/role.py index 325fec91..d5da2c12 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role.py @@ -1,24 +1,24 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from async_property import async_property from cpl.auth.permission.permissions import Permissions from cpl.core.typing import SerialId from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider -class Role(DbModelABC): +class Role(DbModelABC[Self]): def __init__( self, id: SerialId, name: str, description: str, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) self._name = name diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py index 33b60f04..6aea5fbf 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_permission.py @@ -1,46 +1,44 @@ from datetime import datetime -from typing import Optional +from typing import Self from async_property import async_property from cpl.core.typing import SerialId -from cpl.database.abc import DbModelABC -from cpl.dependency import ServiceProvider +from cpl.database.abc import DbJoinModelABC +from cpl.dependency import get_provider -class RolePermission(DbModelABC): +class RolePermission(DbJoinModelABC[Self]): def __init__( self, id: SerialId, role_id: SerialId, permission_id: SerialId, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): - DbModelABC.__init__(self, id, deleted, editor_id, created, updated) - self._role_id = role_id - self._permission_id = permission_id + DbJoinModelABC.__init__(self, id, role_id, permission_id, deleted, editor_id, created, updated) @property def role_id(self) -> int: - return self._role_id + return self._source_id @async_property async def role(self): from cpl.auth.schema._permission.role_dao import RoleDao role_dao: RoleDao = get_provider().get_service(RoleDao) - return await role_dao.get_by_id(self._role_id) + return await role_dao.get_by_id(self._source_id) @property def permission_id(self) -> int: - return self._permission_id + return self._foreign_id @async_property async def permission(self): from cpl.auth.schema._permission.permission_dao import PermissionDao permission_dao: PermissionDao = get_provider().get_service(PermissionDao) - return await permission_dao.get_by_id(self._permission_id) + return await permission_dao.get_by_id(self._foreign_id) diff --git a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py index 6f1f659e..53806c9c 100644 --- a/src/cpl-auth/cpl/auth/schema/_permission/role_user.py +++ b/src/cpl-auth/cpl/auth/schema/_permission/role_user.py @@ -5,7 +5,7 @@ from async_property import async_property from cpl.core.typing import SerialId from cpl.database.abc import DbJoinModelABC -from cpl.dependency import ServiceProvider +from cpl.dependency import ServiceProvider, get_provider class RoleUser(DbJoinModelABC): @@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC): user_id: SerialId, role_id: SerialId, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated) self._user_id = user_id @@ -29,10 +29,10 @@ class RoleUser(DbJoinModelABC): @async_property async def user(self): - from cpl.auth.schema._administration.auth_user_dao import AuthUserDao + from cpl.auth.schema._administration.user_dao import UserDao - auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) - return await auth_user_dao.get_by_id(self._user_id) + user_dao: UserDao = get_provider().get_service(UserDao) + return await user_dao.get_by_id(self._user_id) @property def role_id(self) -> int: diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql b/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql index c3e09082..2226a9c2 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/1-users.sql @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS administration_auth_users +CREATE TABLE IF NOT EXISTS administration_users ( id INT AUTO_INCREMENT PRIMARY KEY, keycloakId CHAR(36) NOT NULL, @@ -9,10 +9,10 @@ CREATE TABLE IF NOT EXISTS administration_auth_users updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UC_KeycloakId UNIQUE (keycloakId), - CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_users (id) ); -CREATE TABLE IF NOT EXISTS administration_auth_users_history +CREATE TABLE IF NOT EXISTS administration_users_history ( id INT NOT NULL, keycloakId CHAR(36) NOT NULL, @@ -23,22 +23,22 @@ CREATE TABLE IF NOT EXISTS administration_auth_users_history updated TIMESTAMP NOT NULL ); -CREATE TRIGGER TR_administration_auth_usersUpdate +CREATE TRIGGER TR_administration_usersUpdate AFTER UPDATE - ON administration_auth_users + ON administration_users FOR EACH ROW BEGIN - INSERT INTO administration_auth_users_history + INSERT INTO administration_users_history (id, keycloakId, deleted, editorId, created, updated) VALUES (OLD.id, OLD.keycloakId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; -CREATE TRIGGER TR_administration_auth_usersDelete +CREATE TRIGGER TR_administration_usersDelete AFTER DELETE - ON administration_auth_users + ON administration_users FOR EACH ROW BEGIN - INSERT INTO administration_auth_users_history + INSERT INTO administration_users_history (id, keycloakId, deleted, editorId, created, updated) VALUES (OLD.id, OLD.keycloakId, 1, OLD.editorId, OLD.created, NOW()); END; \ No newline at end of file diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql b/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql index 134c6c78..09418f91 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/2-api-key.sql @@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys CONSTRAINT UC_Identifier_Key UNIQUE (identifier, keyString), CONSTRAINT UC_Key UNIQUE (keyString), - CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS administration_api_keys_history diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql index f3082a48..23b4ecc8 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/3-roles-permissions.sql @@ -8,7 +8,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UQ_PermissionName UNIQUE (name), - CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_permissions_history @@ -52,7 +52,7 @@ CREATE TABLE IF NOT EXISTS permission_roles created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, CONSTRAINT UQ_RoleName UNIQUE (name), - CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_roles_history @@ -89,22 +89,22 @@ END; CREATE TABLE IF NOT EXISTS permission_role_permissions ( id INT AUTO_INCREMENT PRIMARY KEY, - RoleId INT NOT NULL, + roleId INT NOT NULL, permissionId INT NOT NULL, deleted BOOL NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId), - CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId), + CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, - CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_role_permissions_history ( id INT NOT NULL, - RoleId INT NOT NULL, + roleId INT NOT NULL, permissionId INT NOT NULL, deleted BOOL NOT NULL, editorId INT NULL, @@ -118,8 +118,8 @@ CREATE TRIGGER TR_RolePermissionsUpdate FOR EACH ROW BEGIN INSERT INTO permission_role_permissions_history - (id, RoleId, permissionId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); + (id, roleId, permissionId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; CREATE TRIGGER TR_RolePermissionsDelete @@ -128,52 +128,52 @@ CREATE TRIGGER TR_RolePermissionsDelete FOR EACH ROW BEGIN INSERT INTO permission_role_permissions_history - (id, RoleId, permissionId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); + (id, roleId, permissionId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); END; -CREATE TABLE IF NOT EXISTS permission_role_auth_users +CREATE TABLE IF NOT EXISTS permission_role_users ( id INT AUTO_INCREMENT PRIMARY KEY, - RoleId INT NOT NULL, - UserId INT NOT NULL, + roleId INT NOT NULL, + userId INT NOT NULL, deleted BOOL NOT NULL DEFAULT FALSE, editorId INT NULL, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId), - CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, - CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (UserId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, - CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId), + CONSTRAINT FK_Roleusers_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE, + CONSTRAINT FK_Roleusers_User FOREIGN KEY (userId) REFERENCES administration_users (id) ON DELETE CASCADE, + CONSTRAINT FK_Roleusers_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); -CREATE TABLE IF NOT EXISTS permission_role_auth_users_history +CREATE TABLE IF NOT EXISTS permission_role_users_history ( id INT NOT NULL, - RoleId INT NOT NULL, - UserId INT NOT NULL, + roleId INT NOT NULL, + userId INT NOT NULL, deleted BOOL NOT NULL, editorId INT NULL, created TIMESTAMP NOT NULL, updated TIMESTAMP NOT NULL ); -CREATE TRIGGER TR_Roleauth_usersUpdate +CREATE TRIGGER TR_RoleusersUpdate AFTER UPDATE - ON permission_role_auth_users + ON permission_role_users FOR EACH ROW BEGIN - INSERT INTO permission_role_auth_users_history - (id, RoleId, UserId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.UserId, OLD.deleted, OLD.editorId, OLD.created, NOW()); + INSERT INTO permission_role_users_history + (id, roleId, userId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW()); END; -CREATE TRIGGER TR_Roleauth_usersDelete +CREATE TRIGGER TR_RoleusersDelete AFTER DELETE - ON permission_role_auth_users + ON permission_role_users FOR EACH ROW BEGIN - INSERT INTO permission_role_auth_users_history - (id, RoleId, UserId, deleted, editorId, created, updated) - VALUES (OLD.id, OLD.RoleId, OLD.UserId, 1, OLD.editorId, OLD.created, NOW()); + INSERT INTO permission_role_users_history + (id, roleId, userId, deleted, editorId, created, updated) + VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW()); END; diff --git a/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql b/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql index 8f8253fd..3effa6c0 100644 --- a/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/mysql/4-api-key-permissions.sql @@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId), CONSTRAINT FK_ApiKeyPermissions_ApiKey FOREIGN KEY (apiKeyId) REFERENCES administration_api_keys (id) ON DELETE CASCADE, CONSTRAINT FK_ApiKeyPermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, - CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) + CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id) ); CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql b/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql index 41d15483..1735852a 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/1-users.sql @@ -1,26 +1,26 @@ CREATE SCHEMA IF NOT EXISTS administration; -CREATE TABLE IF NOT EXISTS administration.auth_users +CREATE TABLE IF NOT EXISTS administration.users ( id SERIAL PRIMARY KEY, keycloakId UUID NOT NULL, -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UC_KeycloakId UNIQUE (keycloakId) ); -CREATE TABLE IF NOT EXISTS administration.auth_users_history +CREATE TABLE IF NOT EXISTS administration.users_history ( - LIKE administration.auth_users + LIKE administration.users ); CREATE TRIGGER users_history_trigger BEFORE INSERT OR UPDATE OR DELETE - ON administration.auth_users + ON administration.users FOR EACH ROW EXECUTE FUNCTION public.history_trigger_function(); diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql b/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql index 9944d667..e96ed708 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/2-api-key.sql @@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS administration.api_keys keyString VARCHAR(255) NOT NULL, -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql index 42b9283b..8ac5e1b1 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/3-roles-permissions.sql @@ -9,7 +9,7 @@ CREATE TABLE permission.permissions -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_PermissionName UNIQUE (name) @@ -35,7 +35,7 @@ CREATE TABLE permission.roles -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_RoleName UNIQUE (name) @@ -61,7 +61,7 @@ CREATE TABLE permission.role_permissions -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId) @@ -83,11 +83,11 @@ CREATE TABLE permission.role_users ( id SERIAL PRIMARY KEY, RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE, - UserId INT NOT NULL REFERENCES administration.auth_users (id) ON DELETE CASCADE, + UserId INT NOT NULL REFERENCES administration.users (id) ON DELETE CASCADE, -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId) diff --git a/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql b/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql index 18e0d706..e0d677bb 100644 --- a/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql +++ b/src/cpl-auth/cpl/auth/scripts/postgres/4-api-key-permissions.sql @@ -6,7 +6,7 @@ CREATE TABLE permission.api_key_permissions -- for history deleted BOOLEAN NOT NULL DEFAULT FALSE, - editorId INT NULL REFERENCES administration.auth_users (id), + editorId INT NULL REFERENCES administration.users (id), created timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(), CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId) diff --git a/src/cpl-core/cpl/core/ctx/user_context.py b/src/cpl-core/cpl/core/ctx/user_context.py index a60d69f9..7aaa3584 100644 --- a/src/cpl-core/cpl/core/ctx/user_context.py +++ b/src/cpl-core/cpl/core/ctx/user_context.py @@ -1,13 +1,13 @@ from contextvars import ContextVar from typing import Optional -from cpl.auth.schema._administration.auth_user import AuthUser +from cpl.auth.schema._administration.user import User from cpl.dependency import get_provider -_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) +_user_context: ContextVar[Optional[User]] = ContextVar("user", default=None) -def set_user(user: Optional[AuthUser]): +def set_user(user: Optional[User]): from cpl.core.log.logger_abc import LoggerABC logger = get_provider().get_service(LoggerABC) @@ -15,5 +15,5 @@ def set_user(user: Optional[AuthUser]): _user_context.set(user) -def get_user() -> Optional[AuthUser]: +def get_user() -> Optional[User]: return _user_context.get() diff --git a/src/cpl-core/cpl/core/log/structured_logger.py b/src/cpl-core/cpl/core/log/structured_logger.py index 2d1b9eca..e8e45849 100644 --- a/src/cpl-core/cpl/core/log/structured_logger.py +++ b/src/cpl-core/cpl/core/log/structured_logger.py @@ -68,7 +68,7 @@ class StructuredLogger(Logger): message["request"] = { "url": str(request.url), - "method": request.method, + "method": request.method if request.scope == "http" else "websocket", "scope": self._scope_to_json(request), } if isinstance(request, Request) and request.scope == "http": diff --git a/src/cpl-core/cpl/core/utils/credential_manager.py b/src/cpl-core/cpl/core/utils/credential_manager.py index d030dc94..46df3b43 100644 --- a/src/cpl-core/cpl/core/utils/credential_manager.py +++ b/src/cpl-core/cpl/core/utils/credential_manager.py @@ -2,10 +2,6 @@ import os from cryptography.fernet import Fernet -from cpl.core.log.logger import Logger - -_logger = Logger(__name__) - class CredentialManager: r"""Handles credential encryption and decryption""" @@ -14,6 +10,8 @@ class CredentialManager: @classmethod def with_secret(cls, file: str = None): + from cpl.core.log import Logger + if file is None: file = ".secret" @@ -25,12 +23,12 @@ class CredentialManager: with open(file, "w") as secret_file: secret_file.write(Fernet.generate_key().decode()) secret_file.close() - _logger.warning("Secret file not found, regenerating") + Logger(__name__).warning("Secret file not found, regenerating") with open(file, "r") as secret_file: secret = secret_file.read().strip() if secret == "" or secret is None: - _logger.fatal("No secret found in .secret file.") + Logger(__name__).fatal("No secret found in .secret file.") cls._secret = str(secret) diff --git a/src/cpl-database/cpl/database/abc/data_access_object_abc.py b/src/cpl-database/cpl/database/abc/data_access_object_abc.py index 95a12e05..7f1e235b 100644 --- a/src/cpl-database/cpl/database/abc/data_access_object_abc.py +++ b/src/cpl-database/cpl/database/abc/data_access_object_abc.py @@ -46,6 +46,10 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): def table_name(self) -> str: return self._table_name + @property + def type(self) -> Type[T_DBM]: + return self._model_type + def has_attribute(self, attr_name: Attribute) -> bool: """ Check if the attribute exists in the DAO @@ -81,7 +85,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): self.__ignored_attributes.add(attr_name) if not db_name: - db_name = attr_name.lower().replace("_", "") + db_name = String.to_camel_case(attr_name) self.__db_names[attr_name] = db_name self.__db_names[db_name] = db_name @@ -490,16 +494,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]): table, join_condition = self.__foreign_tables[attr] builder.with_left_join(table, join_condition) - if filters: + if filters is not None: await self._build_conditions(builder, filters, external_table_deps) - if sorts: + if sorts is not None: self._build_sorts(builder, sorts, external_table_deps) - if take: + if take is not None: builder.with_limit(take) - if skip: + if skip is not None: builder.with_offset(skip) for external_table in external_table_deps: diff --git a/src/cpl-database/cpl/database/abc/db_join_model_abc.py b/src/cpl-database/cpl/database/abc/db_join_model_abc.py index c81bd50d..42388418 100644 --- a/src/cpl-database/cpl/database/abc/db_join_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_join_model_abc.py @@ -12,9 +12,9 @@ class DbJoinModelABC[T](DbModelABC[T]): source_id: Id, foreign_id: Id, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): DbModelABC.__init__(self, id, deleted, editor_id, created, updated) diff --git a/src/cpl-database/cpl/database/abc/db_model_abc.py b/src/cpl-database/cpl/database/abc/db_model_abc.py index edbd1f3b..3272bf67 100644 --- a/src/cpl-database/cpl/database/abc/db_model_abc.py +++ b/src/cpl-database/cpl/database/abc/db_model_abc.py @@ -2,7 +2,10 @@ from abc import ABC from datetime import datetime, timezone from typing import Optional, Generic +from async_property import async_property + from cpl.core.typing import Id, SerialId, T +from cpl.dependency import get_provider class DbModelABC(ABC, Generic[T]): @@ -10,9 +13,9 @@ class DbModelABC(ABC, Generic[T]): self, id: Id, deleted: bool = False, - editor_id: Optional[SerialId] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, + editor_id: SerialId | None = None, + created: datetime | None = None, + updated: datetime | None = None, ): self._id = id self._deleted = deleted @@ -41,14 +44,16 @@ class DbModelABC(ABC, Generic[T]): def editor_id(self, value: SerialId): self._editor_id = value - # @async_property - # async def editor(self): - # if self._editor_id is None: - # return None - # - # from data.schemas.administration.user_dao import userDao - # - # return await userDao.get_by_id(self._editor_id) + @async_property + async def editor(self): + if self._editor_id is None: + return None + + from cpl.auth.schema import UserDao + + user_dao = get_provider().get_service(UserDao) + + return await user_dao.get_by_id(self._editor_id) @property def created(self) -> datetime: diff --git a/src/cpl-database/cpl/database/abc/db_model_dao_abc.py b/src/cpl-database/cpl/database/abc/db_model_dao_abc.py index 9d9bfef6..873ba4fd 100644 --- a/src/cpl-database/cpl/database/abc/db_model_dao_abc.py +++ b/src/cpl-database/cpl/database/abc/db_model_dao_abc.py @@ -18,7 +18,7 @@ class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]): self.attribute(DbModelABC.editor_id, int, db_name="editorId", ignore=True) # handled by db trigger self.reference( - "editor", "id", DbModelABC.editor_id, TableManager.get("auth_users") + "editor", "id", DbModelABC.editor_id, TableManager.get("users") ) # not relevant for updates due to editor_id self.attribute(DbModelABC.created, datetime, ignore=True) # handled by db trigger diff --git a/src/cpl-database/cpl/database/model/database_settings.py b/src/cpl-database/cpl/database/model/database_settings.py index ccf1ad44..fa6154af 100644 --- a/src/cpl-database/cpl/database/model/database_settings.py +++ b/src/cpl-database/cpl/database/model/database_settings.py @@ -1,6 +1,6 @@ from typing import Optional -from cpl.core.configuration import Configuration +from cpl.core.configuration.configuration import Configuration from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index a5422761..b482229c 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -1,6 +1,8 @@ from typing import Optional, Any - import sqlparse +import asyncio + +from mysql.connector import errors, PoolError from mysql.connector.aio import MySQLConnectionPool from cpl.core.environment import Environment @@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider class MySQLPool: - def __init__(self, database_settings: DatabaseSettings): self._dbconfig = { "host": database_settings.host, @@ -25,59 +26,87 @@ class MySQLPool: "ssl_disabled": database_settings.ssl_disabled, } self._pool: Optional[MySQLConnectionPool] = None + self._pool_lock = asyncio.Lock() - async def _get_pool(self): + async def _get_pool(self) -> MySQLConnectionPool: if self._pool is None: - try: - self._pool = MySQLConnectionPool( - pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig - ) - await self._pool.initialize_pool() + async with self._pool_lock: + if self._pool is None: + try: + self._pool = MySQLConnectionPool( + pool_name="cplpool", + pool_size=Environment.get("DB_POOL_SIZE", int, 20), + **self._dbconfig, + ) + await self._pool.initialize_pool() - con = await self._pool.get_connection() - async with await con.cursor() as cursor: - await cursor.execute("SELECT 1") - await cursor.fetchall() - - await con.close() - except Exception as e: - logger = get_provider().get_service(DBLogger) - logger.fatal(f"Error connecting to the database", e) + # Testverbindung (Ping) + con = await self._pool.get_connection() + try: + async with await con.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchall() + finally: + await con.close() + except Exception as e: + logger = get_provider().get_service(DBLogger) + logger.fatal("Error connecting to the database", e) + raise return self._pool + async def _get_connection(self, retries: int = 3, delay: float = 0.5): + """Stabiler Connection-Getter mit Retry und Ping""" + pool = await self._get_pool() + + for attempt in range(retries): + try: + con = await pool.get_connection() + + # Verbindungs-Check (Ping) + try: + async with await con.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchall() + except errors.OperationalError: + await con.close() + raise + + return con + + except PoolError: + if attempt == retries - 1: + raise + await asyncio.sleep(delay) + @staticmethod async def _exec_sql(cursor: Any, query: str, args=None, multi=True): result = [] if multi: queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] for q in queries: - if q.strip() == "": - continue - await cursor.execute(q, args) - if cursor.description is not None: - result = await cursor.fetchall() + if q: + await cursor.execute(q, args) + if cursor.description is not None: + result = await cursor.fetchall() else: await cursor.execute(query, args) if cursor.description is not None: result = await cursor.fetchall() - return result - async def execute(self, query: str, args=None, multi=True) -> list[list]: - pool = await self._get_pool() - con = await pool.get_connection() + async def execute(self, query: str, args=None, multi=True) -> list[str]: + con = await self._get_connection() try: async with await con.cursor() as cursor: - result = await self._exec_sql(cursor, query, args, multi) + res = await self._exec_sql(cursor, query, args, multi) await con.commit() - return result + return list(res) finally: await con.close() async def select(self, query: str, args=None, multi=True) -> list[str]: - pool = await self._get_pool() - con = await pool.get_connection() + con = await self._get_connection() try: async with await con.cursor() as cursor: res = await self._exec_sql(cursor, query, args, multi) @@ -86,11 +115,17 @@ class MySQLPool: await con.close() async def select_map(self, query: str, args=None, multi=True) -> list[dict]: - pool = await self._get_pool() - con = await pool.get_connection() + con = await self._get_connection() try: async with await con.cursor(dictionary=True) as cursor: res = await self._exec_sql(cursor, query, args, multi) - return list(res) + decoded_res = [] + for row in res: + decoded_row = { + k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items() + } + decoded_res.append(decoded_row) + + return decoded_res finally: await con.close() diff --git a/src/cpl-database/cpl/database/postgres/postgres_pool.py b/src/cpl-database/cpl/database/postgres/postgres_pool.py index 891fb7f1..434c2655 100644 --- a/src/cpl-database/cpl/database/postgres/postgres_pool.py +++ b/src/cpl-database/cpl/database/postgres/postgres_pool.py @@ -27,7 +27,7 @@ class PostgresPool: self._pool: Optional[AsyncConnectionPool] = None async def _get_pool(self): - if self._pool is None: + if self._pool is None or self._pool.closed: pool = AsyncConnectionPool( conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) ) diff --git a/src/cpl-database/cpl/database/schema/executed_migration.py b/src/cpl-database/cpl/database/schema/executed_migration.py index 3b9ed1c5..b6ec58ac 100644 --- a/src/cpl-database/cpl/database/schema/executed_migration.py +++ b/src/cpl-database/cpl/database/schema/executed_migration.py @@ -1,15 +1,15 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Self from cpl.database.abc import DbModelABC -class ExecutedMigration(DbModelABC): +class ExecutedMigration(DbModelABC[Self]): def __init__( self, migration_id: str, - created: Optional[datetime] = None, - modified: Optional[datetime] = None, + created: datetime | None = None, + modified: datetime | None = None, ): DbModelABC.__init__(self, migration_id, False, created, modified) diff --git a/src/cpl-database/cpl/database/table_manager.py b/src/cpl-database/cpl/database/table_manager.py index 9bd1f6b2..7ca8d4e9 100644 --- a/src/cpl-database/cpl/database/table_manager.py +++ b/src/cpl-database/cpl/database/table_manager.py @@ -7,9 +7,9 @@ class TableManager: ServerTypes.POSTGRES: "system._executed_migrations", ServerTypes.MYSQL: "system__executed_migrations", }, - "auth_users": { - ServerTypes.POSTGRES: "administration.auth_users", - ServerTypes.MYSQL: "administration_auth_users", + "users": { + ServerTypes.POSTGRES: "administration.users", + ServerTypes.MYSQL: "administration_users", }, "api_keys": { ServerTypes.POSTGRES: "administration.api_keys", @@ -33,7 +33,7 @@ class TableManager: }, "role_users": { ServerTypes.POSTGRES: "permission.role_users", - ServerTypes.MYSQL: "permission_role_auth_users", + ServerTypes.MYSQL: "permission_role_users", }, } diff --git a/src/cpl-dependency/cpl/dependency/event_bus.py b/src/cpl-dependency/cpl/dependency/event_bus.py new file mode 100644 index 00000000..efd372aa --- /dev/null +++ b/src/cpl-dependency/cpl/dependency/event_bus.py @@ -0,0 +1,10 @@ +from abc import abstractmethod, ABC +from typing import Any, AsyncGenerator + + +class EventBusABC(ABC): + @abstractmethod + async def publish(self, channel: str, event: Any) -> None: ... + + @abstractmethod + async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ... diff --git a/src/cpl-dependency/cpl/dependency/module/module_abc.py b/src/cpl-dependency/cpl/dependency/module/module_abc.py index 971a721c..9cf0c9f8 100644 --- a/src/cpl-dependency/cpl/dependency/module/module_abc.py +++ b/src/cpl-dependency/cpl/dependency/module/module_abc.py @@ -8,7 +8,7 @@ class ModuleABC(ABC): __OPTIONAL_VARS = ["dependencies", "configuration", "singleton", "scoped", "transient", "hosted"] def __init_subclass__(cls): - super().__init_subclass__() + ABC.__init_subclass__() if f"{cls.__module__}.{cls.__name__}" == "cpl.dependency.module.module.Module": return diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index 23a4216d..38e0ae46 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -25,7 +25,7 @@ class ServiceProvider: for descriptor in self._service_descriptors: if typing.get_origin(service_type) is None and ( - descriptor.service_type == service_type + descriptor.service_type.__name__ == service_type.__name__ or typing.get_origin(descriptor.base_type) is None and issubclass(descriptor.base_type, service_type) ): diff --git a/src/cpl-graphql/cpl/graphql/__init__.py b/src/cpl-graphql/cpl/graphql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/__init__.py b/src/cpl-graphql/cpl/graphql/_endpoints/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py new file mode 100644 index 00000000..a369fd64 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphiql.py @@ -0,0 +1,69 @@ +from starlette.responses import HTMLResponse + + +async def graphiql_endpoint(request): + return HTMLResponse( + """ + + + + + GraphiQL + + + +
+ + + + + + + + + + + + + + + """ + ) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py new file mode 100644 index 00000000..01cb133b --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/graphql.py @@ -0,0 +1,13 @@ +from starlette.requests import Request +from starlette.responses import Response, JSONResponse + +from cpl.graphql.service.graphql import GraphQLService + + +async def graphql_endpoint(request: Request, service: GraphQLService) -> Response: + body = await request.json() + query = body.get("query") + variables = body.get("variables") + + response_data = await service.execute(query, variables, request) + return JSONResponse(response_data) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py b/src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py new file mode 100644 index 00000000..e70970c9 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/lazy_graphql_app.py @@ -0,0 +1,27 @@ +from starlette.requests import Request +from starlette.responses import Response +from strawberry.asgi import GraphQL +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL + +from cpl.dependency import ServiceProvider +from cpl.graphql.service.schema import Schema + + +class LazyGraphQLApp: + + def __init__(self, services: ServiceProvider): + self._services = services + self._graphql_app = None + + async def __call__(self, scope, receive, send): + if self._graphql_app is None: + schema = self._services.get_service(Schema) + if not schema or not schema.schema: + raise RuntimeError("GraphQL Schema not available yet") + + self._graphql_app = GraphQL( + schema.schema, + subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL], + ) + + await self._graphql_app(scope, receive, send) diff --git a/src/cpl-graphql/cpl/graphql/_endpoints/playground.py b/src/cpl-graphql/cpl/graphql/_endpoints/playground.py new file mode 100644 index 00000000..969cd506 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/_endpoints/playground.py @@ -0,0 +1,29 @@ +from starlette.requests import Request +from starlette.responses import Response, HTMLResponse + + +async def playground_endpoint(request: Request) -> Response: + return HTMLResponse( + """ + + + + + GraphQL Playground + + + + + +
+ + + + """ + ) diff --git a/src/cpl-graphql/cpl/graphql/abc/__init__.py b/src/cpl-graphql/cpl/graphql/abc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/abc/query_abc.py b/src/cpl-graphql/cpl/graphql/abc/query_abc.py new file mode 100644 index 00000000..1c7cb648 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/query_abc.py @@ -0,0 +1,227 @@ +import functools +import inspect +import types +from abc import ABC +from asyncio import iscoroutinefunction +from typing import Callable, Type, Any, Optional + +import strawberry +from async_property.base import AsyncPropertyDescriptor +from strawberry.exceptions import StrawberryException + +from cpl.api import Unauthorized, Forbidden +from cpl.core.ctx.user_context import get_user +from cpl.dependency import get_provider +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.error import graphql_error +from cpl.graphql.query_context import QueryContext +from cpl.graphql.schema.field import Field +from cpl.graphql.typing import Resolver, AttributeName +from cpl.graphql.utils.type_collector import TypeCollector + + +class QueryABC(StrawberryProtocol, ABC): + + def __init__(self): + ABC.__init__(self) + self._fields: dict[str, Field] = {} + + @property + def fields(self) -> dict[str, Field]: + return self._fields + + @property + def fields_count(self) -> int: + return len(self._fields) + + def get_fields(self) -> dict[str, Field]: + return self._fields + + def field( + self, + name: AttributeName, + t: type, + resolver: Resolver = None, + ) -> Field: + from cpl.graphql.schema.field import Field + + if isinstance(name, property): + name = name.fget.__name__ + + self._fields[name] = Field(name, t, resolver) + return self._fields[name] + + def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field: + return self.field(name, str, resolver) + + def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field: + return self.field(name, int, resolver) + + def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field: + return self.field(name, float, resolver) + + def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field: + return self.field(name, bool, resolver) + + def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field: + return self.field(name, list[t], resolver) + + def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field: + if not isinstance(t, type) and callable(t): + return self.field(name, t, resolver) + + return self.field(name, t().to_strawberry(), resolver) + + @staticmethod + def _build_resolver(f: "Field"): + params: list[inspect.Parameter] = [] + for arg in f.arguments.values(): + _type = arg.type + if isinstance(_type, type) and issubclass(_type, StrawberryProtocol): + _type = _type().to_strawberry() + + ann = Optional[_type] if arg.optional else _type + + if arg.default is None: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + ) + else: + param = inspect.Parameter( + arg.name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ann, + default=arg.default, + ) + + params.append(param) + + sig = inspect.Signature(parameters=params, return_annotation=f.type) + + async def _resolver(*args, **kwargs): + if f.resolver is None: + return None + + if iscoroutinefunction(f.resolver): + return await f.resolver(*args, **kwargs) + return f.resolver(*args, **kwargs) + + _resolver.__signature__ = sig + return _resolver + + def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable: + sig = getattr(resolver, "__signature__", None) + + @functools.wraps(resolver) + async def _auth_resolver(*args, **kwargs): + if f.public: + return await self._run_resolver(resolver, *args, **kwargs) + + user = get_user() + + if user is None: + raise graphql_error(Unauthorized(f"{f.name}: Authentication required")) + + if f.require_any_permission: + if not any([await user.has_permission(p) for p in f.require_any_permission]): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + if f.require_any: + perms, resolvers = f.require_any + if not any([await user.has_permission(p) for p in perms]): + ctx = QueryContext([x.name for x in await user.permissions]) + resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers] + + if not any(resolved): + raise graphql_error(Forbidden(f"{f.name}: Permission denied")) + + return await self._run_resolver(resolver, *args, **kwargs) + + if sig: + _auth_resolver.__signature__ = sig + + return _auth_resolver + + @staticmethod + async def _run_resolver(r: Callable, *args, **kwargs): + result = r(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + def _field_to_strawberry(self, f: Field) -> Any: + resolver = None + try: + if f.arguments: + resolver = self._build_resolver(f) + elif not f.resolver: + resolver = lambda root: None + else: + ann = getattr(f.resolver, "__annotations__", {}) + if "return" not in ann or ann["return"] is None: + ann = dict(ann) + ann["return"] = f.type + f.resolver.__annotations__ = ann + resolver = f.resolver + + return strawberry.field(resolver=self._wrap_with_auth(f, resolver)) + except StrawberryException as e: + raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e + + @staticmethod + def _type_to_strawberry(t: Type) -> Type: + _t = get_provider().get_service(t) + + if isinstance(_t, StrawberryProtocol): + return _t.to_strawberry() + + return t + + def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + + gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {}) + # register early to handle recursive types + TypeCollector.set(cls, gql_cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + t = f.type + if isinstance(name, property): + name = name.fget.__name__ + if isinstance(name, AsyncPropertyDescriptor): + name = name.field_name + + if isinstance(t, types.GenericAlias): + t = t.__args__[0] + + if callable(t) and not isinstance(t, type): + t = self._type_to_strawberry(t()) + elif issubclass(t, StrawberryProtocol): + t = self._type_to_strawberry(t) + + annotations[name] = t if not f.optional else Optional[t] + namespace[name] = self._field_to_strawberry(f) + + namespace["__annotations__"] = annotations + for k, v in namespace.items(): + if isinstance(k, property): + k = k.fget.__name__ + if isinstance(k, AsyncPropertyDescriptor): + k = k.field_name + + setattr(gql_cls, k, v) + + try: + gql_cls.__annotations__ = annotations + gql_type = strawberry.type(gql_cls) + except Exception as e: + raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py new file mode 100644 index 00000000..ad8f18b8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/abc/strawberry_protocol.py @@ -0,0 +1,11 @@ +from typing import Protocol, Type, runtime_checkable + +from cpl.graphql.schema.field import Field +from cpl.graphql.schema.subscription_field import SubscriptionField + + +@runtime_checkable +class StrawberryProtocol(Protocol): + def to_strawberry(self) -> Type: ... + + def get_fields(self) -> dict[str, Field | SubscriptionField]: ... diff --git a/src/cpl-graphql/cpl/graphql/application/__init__.py b/src/cpl-graphql/cpl/graphql/application/__init__.py new file mode 100644 index 00000000..96b2346c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/application/__init__.py @@ -0,0 +1 @@ +from .graphql_app import WebApp diff --git a/src/cpl-graphql/cpl/graphql/application/graphql_app.py b/src/cpl-graphql/cpl/graphql/application/graphql_app.py new file mode 100644 index 00000000..4730006f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/application/graphql_app.py @@ -0,0 +1,126 @@ +import socket +from enum import Enum +from typing import Self + +from cpl.api.application import WebApp +from cpl.api.model.validation_match import ValidationMatch +from cpl.application.abc.application_abc import __not_implemented__ +from cpl.core.environment import Environment +from cpl.dependency.service_provider import ServiceProvider +from cpl.dependency.typing import Modules +from cpl.graphql._endpoints.graphiql import graphiql_endpoint +from cpl.graphql._endpoints.graphql import graphql_endpoint +from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp +from cpl.graphql._endpoints.playground import playground_endpoint +from cpl.graphql.graphql_module import GraphQLModule +from cpl.graphql.service.schema import Schema + + +class GraphQLApp(WebApp): + def __init__(self, services: ServiceProvider, modules: Modules): + WebApp.__init__(self, services, modules, [GraphQLModule]) + + self._with_graphiql = False + self._with_playground = False + + def with_graphql( + self, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Schema: + self.with_route( + path="/api/graphql", + fn=graphql_endpoint, + method="POST", + authentication=authentication, + roles=roles, + permissions=permissions, + policies=policies, + match=match, + ) + schema = self._services.get_service(Schema) + if schema is None: + self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.") + # + # graphql_ws_app = GraphQL( + # schema, + # subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL], + # ) + self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services)) + return schema + + def with_graphiql( + self, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self.with_route( + path="/api/graphiql", + fn=graphiql_endpoint, + method="GET", + authentication=authentication, + roles=roles, + permissions=permissions, + policies=policies, + match=match, + ) + self._with_graphiql = True + return self + + def with_playground( + self, + authentication: bool = False, + roles: list[str | Enum] = None, + permissions: list[str | Enum] = None, + policies: list[str] = None, + match: ValidationMatch = None, + ) -> Self: + self.with_route( + path="/api/playground", + fn=playground_endpoint, + method="GET", + authentication=authentication, + roles=roles, + permissions=permissions, + policies=policies, + match=match, + ) + self._with_playground = True + return self + + def with_auth_root_queries(self, public: bool = False): + try: + from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule + + GraphQLAuthModule.with_auth_root_queries(self._services, public=public) + except ImportError: + __not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations) + + def with_auth_root_mutations(self, public: bool = False): + try: + from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule + + GraphQLAuthModule.with_auth_root_mutations(self._services, public=public) + except ImportError: + __not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations) + + async def _log_before_startup(self): + host = self._api_settings.host + if host == "0.0.0.0" and Environment.get_environment() == "development": + host = "localhost" + elif host == "0.0.0.0": + host = socket.gethostbyname(socket.gethostname()) + + self._logger.info(f"Start API on {host}:{self._api_settings.port}") + if self._with_graphiql: + self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql") + if self._with_playground: + self._logger.warning( + f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground" + ) diff --git a/src/cpl-graphql/cpl/graphql/auth/__init__.py b/src/cpl-graphql/cpl/graphql/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/__init__.py b/src/cpl-graphql/cpl/graphql/auth/api_key/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_filter.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_filter.py new file mode 100644 index 00000000..9c5752d2 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_filter.py @@ -0,0 +1,10 @@ +from cpl.auth.schema import ApiKey +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class ApiKeyFilter(DbModelFilter[ApiKey]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("identifier", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_graph_type.py new file mode 100644 index 00000000..0bb52bbb --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_graph_type.py @@ -0,0 +1,14 @@ +from cpl.auth.schema import ApiKey, RolePermissionDao +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class ApiKeyGraphType(DbModelGraphType[ApiKey]): + + def __init__(self, role_permission_dao: RolePermissionDao): + DbModelGraphType.__init__(self) + + self.string_field(ApiKey.identifier, lambda root: root.identifier) + self.string_field(ApiKey.key, lambda root: root.key) + self.string_field(ApiKey.permissions, lambda root: root.permissions) + + self.set_history_reference_dao(role_permission_dao, "apikeyid") diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_input.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_input.py new file mode 100644 index 00000000..a669fce1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_input.py @@ -0,0 +1,25 @@ +from cpl.auth.schema import ApiKey +from cpl.core.typing import SerialId +from cpl.graphql.schema.input import Input + + +class ApiKeyCreateInput(Input[ApiKey]): + identifier: str + permissions: list[SerialId] + + def __init__(self): + Input.__init__(self) + self.string_field("identifier").with_required() + self.list_field("permissions", SerialId) + + +class ApiKeyUpdateInput(Input[ApiKey]): + id: SerialId + identifier: str | None + permissions: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("identifier").with_required() + self.list_field("permissions", SerialId) diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py new file mode 100644 index 00000000..dd3a4665 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_mutation.py @@ -0,0 +1,93 @@ +from cpl.api import APILogger +from cpl.auth.keycloak import KeycloakAdmin +from cpl.auth.permission import Permissions +from cpl.auth.schema import ApiKey, ApiKeyDao, ApiKeyPermissionDao, ApiKeyPermission +from cpl.graphql.auth.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput +from cpl.graphql.schema.mutation import Mutation + + +class ApiKeyMutation(Mutation): + def __init__( + self, + logger: APILogger, + api_key_dao: ApiKeyDao, + api_key_permission_dao: ApiKeyPermissionDao, + permission_dao: ApiKeyPermissionDao, + keycloak_admin: KeycloakAdmin, + ): + Mutation.__init__(self) + self._logger = logger + self._api_key_dao = api_key_dao + self._api_key_permission_dao = api_key_permission_dao + self._permission_dao = permission_dao + self._keycloak_admin = keycloak_admin + + self.int_field( + "create", + self.resolve_create, + ).with_require_any_permission(Permissions.api_keys_create).with_argument( + "input", + ApiKeyCreateInput, + ).with_required() + + self.bool_field( + "update", + self.resolve_update, + ).with_require_any_permission(Permissions.api_keys_update).with_argument( + "input", + ApiKeyUpdateInput, + ).with_required() + + self.bool_field( + "delete", + self.resolve_delete, + ).with_require_any_permission(Permissions.api_keys_delete).with_argument( + "id", + int, + ).with_required() + + self.bool_field( + "restore", + self.resolve_restore, + ).with_require_any_permission(Permissions.api_keys_delete).with_argument( + "id", + int, + ).with_required() + + async def resolve_create(self, obj: ApiKeyCreateInput): + self._logger.debug(f"create api key: {obj.__dict__}") + + api_key = ApiKey.new(obj.identifier) + await self._api_key_dao.create(api_key) + api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}]) + await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions]) + return api_key + + async def resolve_update(self, input: ApiKeyUpdateInput): + self._logger.debug(f"update api key: {input}") + api_key = await self._api_key_dao.get_by_id(input.id) + + await self._resolve_assignments( + input.permissions or [], + api_key, + ApiKeyPermission.api_key_id, + ApiKeyPermission.permission_id, + self._api_key_dao, + self._api_key_permission_dao, + ApiKeyPermission, + self._permission_dao, + ) + + return api_key + + async def resolve_delete(self, id: str): + self._logger.debug(f"delete api key: {id}") + api_key = await self._api_key_dao.get_by_id(id) + await self._api_key_dao.delete(api_key) + return True + + async def resolve_restore(self, id: str): + self._logger.debug(f"restore api key: {id}") + api_key = await self._api_key_dao.get_by_id(id) + await self._api_key_dao.restore(api_key) + return True diff --git a/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py new file mode 100644 index 00000000..af3d0c18 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/api_key/api_key_sort.py @@ -0,0 +1,9 @@ +from cpl.auth.schema import ApiKey +from cpl.graphql.schema.sort.db_model_sort import DbModelSort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class ApiKeySort(DbModelSort[ApiKey]): + def __init__(self): + DbModelSort.__init__(self) + self.field("identifier", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py new file mode 100644 index 00000000..7ce2a0b4 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/graphql_auth_module.py @@ -0,0 +1,77 @@ +from cpl.auth.permission import Permissions +from cpl.auth.schema import UserDao, ApiKeyDao, RoleDao +from cpl.core.configuration import Configuration +from cpl.dependency import ServiceProvider +from cpl.dependency.module.module import Module +from cpl.dependency.service_collection import ServiceCollection +from cpl.graphql.auth.api_key.api_key_filter import ApiKeyFilter +from cpl.graphql.auth.api_key.api_key_graph_type import ApiKeyGraphType +from cpl.graphql.auth.api_key.api_key_mutation import ApiKeyMutation +from cpl.graphql.auth.api_key.api_key_sort import ApiKeySort +from cpl.graphql.auth.role.role_filter import RoleFilter +from cpl.graphql.auth.role.role_graph_type import RoleGraphType +from cpl.graphql.auth.role.role_mutation import RoleMutation +from cpl.graphql.auth.role.role_sort import RoleSort +from cpl.graphql.auth.user.user_filter import UserFilter +from cpl.graphql.auth.user.user_graph_type import UserGraphType +from cpl.graphql.auth.user.user_mutation import UserMutation +from cpl.graphql.auth.user.user_sort import UserSort +from cpl.graphql.graphql_module import GraphQLModule +from cpl.graphql.service.schema import Schema + + +class GraphQLAuthModule(Module): + dependencies = [GraphQLModule] + transient = [ + UserGraphType, + UserMutation, + UserFilter, + UserSort, + ApiKeyGraphType, + ApiKeyMutation, + ApiKeyFilter, + ApiKeySort, + RoleGraphType, + RoleMutation, + RoleFilter, + RoleSort, + ] + + @staticmethod + def register(collection: ServiceCollection): + Configuration.set("GraphQLAuthModuleEnabled", True) + + @staticmethod + def configure(provider: ServiceProvider): + schema = provider.get_service(Schema) + schema.with_type(UserGraphType) + schema.with_type(ApiKeyGraphType) + schema.with_type(RoleGraphType) + + @staticmethod + def with_auth_root_queries(provider: ServiceProvider, public: bool = False): + if not Configuration.get("GraphQLAuthModuleEnabled", False): + raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + + schema = provider.get_service(Schema) + schema.query.dao_collection_field( + UserGraphType, UserDao, "users", UserFilter, UserSort + ).with_require_any_permission(Permissions.users).with_public(public) + + schema.query.dao_collection_field( + ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort + ).with_require_any_permission(Permissions.api_keys).with_public(public) + + schema.query.dao_collection_field( + RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort + ).with_require_any_permission(Permissions.roles).with_public(public) + + @staticmethod + def with_auth_root_mutations(provider: ServiceProvider, public: bool = False): + if not Configuration.get("GraphQLAuthModuleEnabled", False): + raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'") + + schema = provider.get_service(Schema) + schema.mutation.with_mutation("user", UserMutation).with_public(public) + schema.mutation.with_mutation("apiKey", ApiKeyMutation).with_public(public) + schema.mutation.with_mutation("role", RoleMutation).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/__init__.py b/src/cpl-graphql/cpl/graphql/auth/role/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_filter.py b/src/cpl-graphql/cpl/graphql/auth/role/role_filter.py new file mode 100644 index 00000000..f31dbf4f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_filter.py @@ -0,0 +1,11 @@ +from cpl.auth.schema import User, Role +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class RoleFilter(DbModelFilter[Role]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("name", StringFilter).with_public(public) + self.field("description", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py new file mode 100644 index 00000000..27ce9309 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_graph_type.py @@ -0,0 +1,14 @@ +from cpl.auth.schema import Role +from cpl.graphql.auth.user.user_graph_type import UserGraphType +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class RoleGraphType(DbModelGraphType[Role]): + + def __init__(self, public: bool = False): + DbModelGraphType.__init__(self) + + self.string_field("name", lambda root: root.name).with_public(public) + self.string_field("description", lambda root: root.description).with_public(public) + self.list_field("permissions", str, lambda root: root.permissions).with_public(public) + self.list_field("users", UserGraphType, lambda root: root.users).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_input.py b/src/cpl-graphql/cpl/graphql/auth/role/role_input.py new file mode 100644 index 00000000..7ae1334f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_input.py @@ -0,0 +1,29 @@ +from cpl.auth.schema import User, Role +from cpl.core.typing import SerialId +from cpl.graphql.schema.input import Input + + +class RoleCreateInput(Input[Role]): + name: str + description: str | None + permissions: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.string_field("name").with_required() + self.string_field("description") + self.list_field("permissions", SerialId) + + +class RoleUpdateInput(Input[Role]): + id: SerialId + name: str | None + description: str | None + permissions: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.string_field("name") + self.string_field("description") + self.list_field("permissions", SerialId) diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py b/src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py new file mode 100644 index 00000000..df7d06d8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_mutation.py @@ -0,0 +1,101 @@ +from cpl.api import APILogger +from cpl.auth.keycloak import KeycloakAdmin +from cpl.auth.permission import Permissions +from cpl.auth.schema import RoleDao, Role, RolePermissionDao, RolePermission +from cpl.graphql.auth.role.role_input import RoleCreateInput, RoleUpdateInput +from cpl.graphql.schema.mutation import Mutation + + +class RoleMutation(Mutation): + def __init__( + self, + logger: APILogger, + role_dao: RoleDao, + role_permission_dao: RolePermissionDao, + permission_dao: RolePermissionDao, + keycloak_admin: KeycloakAdmin, + ): + Mutation.__init__(self) + self._logger = logger + self._role_dao = role_dao + self._role_permission_dao = role_permission_dao + self._permission_dao = permission_dao + self._keycloak_admin = keycloak_admin + + self.int_field( + "create", + self.resolve_create, + ).with_require_any_permission(Permissions.roles_create).with_argument( + "input", + RoleCreateInput, + ).with_required() + + self.bool_field( + "update", + self.resolve_update, + ).with_require_any_permission(Permissions.roles_update).with_argument( + "input", + RoleUpdateInput, + ).with_required() + + self.bool_field( + "delete", + self.resolve_delete, + ).with_require_any_permission(Permissions.roles_delete).with_argument( + "id", + int, + ).with_required() + + self.bool_field( + "restore", + self.resolve_restore, + ).with_require_any_permission(Permissions.roles_delete).with_argument( + "id", + int, + ).with_required() + + async def resolve_create(self, input: RoleCreateInput, *_): + self._logger.debug(f"create role: {input.__dict__}") + + role = Role( + 0, + input.name, + input.description, + ) + await self._role_dao.create(role) + role = await self._role_dao.get_by_name(role.name) + await self._role_permission_dao.create_many([RolePermission(0, role.id, x) for x in input.permissions]) + + return role + + async def resolve_update(self, input: RoleUpdateInput, *_): + self._logger.debug(f"update role: {input.__dict__}") + role = await self._role_dao.get_by_id(input.id) + role.name = input.get("name", role.name) + role.description = input.get("description", role.description) + await self._role_dao.update(role) + + await self._resolve_assignments( + input.get("permissions", []), + role, + RolePermission.role_id, + RolePermission.permission_id, + self._role_dao, + self._role_permission_dao, + RolePermission, + self._permission_dao, + ) + + return role + + async def resolve_delete(self, id: int): + self._logger.debug(f"delete role: {id}") + role = await self._role_dao.get_by_id(id) + await self._role_dao.delete(role) + return True + + async def resolve_restore(self, id: int): + self._logger.debug(f"restore role: {id}") + role = await self._role_dao.get_by_id(id) + await self._role_dao.restore(role) + return True diff --git a/src/cpl-graphql/cpl/graphql/auth/role/role_sort.py b/src/cpl-graphql/cpl/graphql/auth/role/role_sort.py new file mode 100644 index 00000000..6c55568e --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/role/role_sort.py @@ -0,0 +1,10 @@ +from cpl.auth.schema import Role +from cpl.graphql.schema.sort.db_model_sort import DbModelSort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class RoleSort(DbModelSort[Role]): + def __init__(self): + DbModelSort.__init__(self) + self.field("name", SortOrder) + self.field("description", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/auth/user/__init__.py b/src/cpl-graphql/cpl/graphql/auth/user/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_filter.py b/src/cpl-graphql/cpl/graphql/auth/user/user_filter.py new file mode 100644 index 00000000..991e6efb --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_filter.py @@ -0,0 +1,11 @@ +from cpl.auth.schema import User +from cpl.graphql.schema.filter.db_model_filter import DbModelFilter +from cpl.graphql.schema.filter.string_filter import StringFilter + + +class UserFilter(DbModelFilter[User]): + def __init__(self, public: bool = False): + DbModelFilter.__init__(self, public) + + self.field("username", StringFilter).with_public(public) + self.field("email", StringFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py b/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py new file mode 100644 index 00000000..f0ffa1ab --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_graph_type.py @@ -0,0 +1,12 @@ +from cpl.auth.schema import User +from cpl.graphql.schema.db_model_graph_type import DbModelGraphType + + +class UserGraphType(DbModelGraphType[User]): + + def __init__(self, public: bool = False): + DbModelGraphType.__init__(self) + + self.string_field(User.keycloak_id, lambda root: root.keycloak_id).with_public(public) + self.string_field(User.username, lambda root: root.username).with_public(public) + self.string_field(User.email, lambda root: root.email).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_input.py b/src/cpl-graphql/cpl/graphql/auth/user/user_input.py new file mode 100644 index 00000000..c5f5ac07 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_input.py @@ -0,0 +1,23 @@ +from cpl.auth.schema import User +from cpl.core.typing import SerialId +from cpl.graphql.schema.input import Input + + +class UserCreateInput(Input[User]): + keycloak_id: str + roles: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.string_field("keycloak_id").with_required() + self.list_field("roles", SerialId) + + +class UserUpdateInput(Input[User]): + id: SerialId + roles: list[SerialId] | None + + def __init__(self): + Input.__init__(self) + self.int_field("id").with_required() + self.list_field("roles", SerialId) diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_mutation.py b/src/cpl-graphql/cpl/graphql/auth/user/user_mutation.py new file mode 100644 index 00000000..59afb752 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_mutation.py @@ -0,0 +1,112 @@ +from cpl.api import APILogger +from cpl.auth.keycloak import KeycloakAdmin +from cpl.auth.permission import Permissions +from cpl.auth.schema import UserDao, User, RoleUser, RoleUserDao, RoleDao +from cpl.core.ctx.user_context import get_user +from cpl.graphql.auth.user.user_input import UserCreateInput, UserUpdateInput +from cpl.graphql.schema.mutation import Mutation + + +class UserMutation(Mutation): + def __init__( + self, + logger: APILogger, + user_dao: UserDao, + role_user_dao: RoleUserDao, + role_dao: RoleDao, + keycloak_admin: KeycloakAdmin, + ): + Mutation.__init__(self) + self._logger = logger + self._user_dao = user_dao + self._role_user_dao = role_user_dao + self._role_dao = role_dao + self._keycloak_admin = keycloak_admin + + self.int_field( + "create", + self.resolve_create, + ).with_require_any_permission(Permissions.users_create).with_argument( + "input", + UserCreateInput, + ).with_required() + + self.bool_field( + "update", + self.resolve_update, + ).with_require_any_permission(Permissions.users_update).with_argument( + "input", + UserUpdateInput, + ).with_required() + + self.bool_field( + "delete", + self.resolve_delete, + ).with_require_any_permission(Permissions.users_delete).with_argument( + "id", + int, + ).with_required() + + self.bool_field( + "restore", + self.resolve_restore, + ).with_require_any_permission(Permissions.users_delete).with_argument( + "id", + int, + ).with_required() + + async def resolve_create(self, input: UserCreateInput): + self._logger.debug(f"create user: {input.__dict__}") + + # ensure keycloak knows a user with this keycloak_id + # get_user should raise an exception if the user does not exist + kc_user = self._keycloak_admin.get_user(input.keycloak_id) + if kc_user is None: + raise ValueError(f"Keycloak user with id {input.keycloak_id} does not exist") + + user = User(0, input.keycloak_id, input.license) + user_id = await self._user_dao.create(user) + user = await self._user_dao.get_by_id(user_id) + await self._role_user_dao.create_many([RoleUser(0, user.id, x) for x in set(input.roles)]) + + return user + + async def resolve_update(self, input: UserUpdateInput): + self._logger.debug(f"update user: {input.__dict__}") + user = await self._user_dao.get_by_id(input.id) + + if input.license: + user.license = input.license + + await self._user_dao.update(user) + await self._resolve_assignments( + input.roles or [], + user, + RoleUser.user_id, + RoleUser.role_id, + self._user_dao, + self._role_user_dao, + RoleUser, + self._role_dao, + ) + + return user + + async def resolve_delete(self, id: int): + self._logger.debug(f"delete user: {id}") + user = await self._user_dao.get_by_id(id) + await self._user_dao.delete(user) + try: + active_user = get_user() + if active_user is not None and active_user.id == user.id: + # await broadcast.publish("userLogout", user.id) + self._keycloak_admin.user_logout(user_id=user.keycloak_id) + except Exception as e: + self._logger.error(f"Failed to logout user from Keycloak", e) + return True + + async def resolve_restore(self, id: int): + self._logger.debug(f"restore user: {id}") + user = await self._user_dao.get_by_id(id) + await self._user_dao.restore(user) + return True diff --git a/src/cpl-graphql/cpl/graphql/auth/user/user_sort.py b/src/cpl-graphql/cpl/graphql/auth/user/user_sort.py new file mode 100644 index 00000000..fe0cb8b1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/auth/user/user_sort.py @@ -0,0 +1,10 @@ +from cpl.auth.schema import User +from cpl.graphql.schema.sort.db_model_sort import DbModelSort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class UserSort(DbModelSort[User]): + def __init__(self): + DbModelSort.__init__(self) + self.field("username", SortOrder) + self.field("email", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/error.py b/src/cpl-graphql/cpl/graphql/error.py new file mode 100644 index 00000000..ecab2c06 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/error.py @@ -0,0 +1,14 @@ +from graphql import GraphQLError + +from cpl.api import APIError + + +def graphql_error(api_error: APIError) -> GraphQLError: + """Convert an APIError (from cpl-api) into a GraphQL-friendly error.""" + return GraphQLError( + message=api_error.error_message, + extensions={ + "code": api_error.status_code, + }, + original_error=api_error, + ) diff --git a/src/cpl-graphql/cpl/graphql/event_bus/__init__.py b/src/cpl-graphql/cpl/graphql/event_bus/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/event_bus/memory.py b/src/cpl-graphql/cpl/graphql/event_bus/memory.py new file mode 100644 index 00000000..4d74c1af --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/event_bus/memory.py @@ -0,0 +1,27 @@ +import asyncio +from typing import Any, AsyncGenerator + +from cpl.dependency.event_bus import EventBusABC + + +class InMemoryEventBus(EventBusABC): + def __init__(self): + self._subscribers: dict[str, list[asyncio.Queue]] = {} + + async def publish(self, channel: str, event: Any) -> None: + queues = self._subscribers.get(channel, []) + for q in queues.copy(): + await q.put(event) + + async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: + q = asyncio.Queue() + if channel not in self._subscribers: + self._subscribers[channel] = [] + self._subscribers[channel].append(q) + + try: + while True: + item = await q.get() + yield item + finally: + self._subscribers[channel].remove(q) diff --git a/src/cpl-graphql/cpl/graphql/graphql_module.py b/src/cpl-graphql/cpl/graphql/graphql_module.py new file mode 100644 index 00000000..3672e119 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/graphql_module.py @@ -0,0 +1,25 @@ +from cpl.api.api_module import ApiModule +from cpl.dependency.module.module import Module +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter +from cpl.graphql.schema.root_mutation import RootMutation +from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.schema.root_subscription import RootSubscription +from cpl.graphql.service.graphql import GraphQLService +from cpl.graphql.service.schema import Schema + + +class GraphQLModule(Module): + dependencies = [ApiModule] + singleton = [Schema, RootQuery, RootMutation, RootSubscription] + scoped = [GraphQLService] + transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter] + + @staticmethod + def configure(services: ServiceProvider) -> None: + schema = services.get_service(Schema) + schema.build() diff --git a/src/cpl-graphql/cpl/graphql/query_context.py b/src/cpl-graphql/cpl/graphql/query_context.py new file mode 100644 index 00000000..831273c4 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/query_context.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import Optional + +from graphql import GraphQLResolveInfo + +from cpl.auth.schema import User, Permission +from cpl.core.ctx import get_user + + +class QueryContext: + + def __init__(self, user_permissions: Optional[list[Enum | Permission]], is_mutation: bool = False, *args, **kwargs): + self._user = get_user() + self._user_permissions = user_permissions or [] + + self._resolve_info = None + for arg in args: + if isinstance(arg, GraphQLResolveInfo): + self._resolve_info = arg + continue + + self._args = args + self._kwargs = kwargs + + self._is_mutation = is_mutation + + @property + def user(self) -> User: + return self._user + + @property + def resolve_info(self) -> Optional[GraphQLResolveInfo]: + return self._resolve_info + + @property + def args(self) -> tuple: + return self._args + + @property + def kwargs(self) -> dict: + return self._kwargs + + @property + def is_mutation(self) -> bool: + return self._is_mutation + + def has_permission(self, permission: Enum | str) -> bool: + return permission.value if isinstance(permission, Enum) else permission in self._user_permissions diff --git a/src/cpl-graphql/cpl/graphql/schema/__init__.py b/src/cpl-graphql/cpl/graphql/schema/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/schema/argument.py b/src/cpl-graphql/cpl/graphql/schema/argument.py new file mode 100644 index 00000000..3332ddd0 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/argument.py @@ -0,0 +1,54 @@ +from typing import Any, Self + + +class Argument: + + def __init__( + self, + name: str, + t: type, + description: str = None, + default: Any = None, + optional: bool = None, + ): + self._name = name + self._type = t + self._description = description + self._default = default + self._optional = optional + + @property + def name(self) -> str: + return self._name + + @property + def type(self) -> type: + return self._type + + @property + def description(self) -> str | None: + return self._description + + @property + def default(self) -> Any | None: + return self._default + + @property + def optional(self) -> bool | None: + return self._optional + + def with_description(self, description: str) -> Self: + self._description = description + return self + + def with_default(self, default: Any) -> Self: + self._default = default + return self + + def with_optional(self, optional: bool) -> Self: + self._optional = optional + return self + + def with_required(self, required: bool = True) -> Self: + self._optional = not required + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/collection.py b/src/cpl-graphql/cpl/graphql/schema/collection.py new file mode 100644 index 00000000..650fc71e --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/collection.py @@ -0,0 +1,61 @@ +from typing import Type, Dict, List + +import strawberry + +from cpl.core.typing import T +from cpl.dependency import get_provider +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol + + +from cpl.graphql.utils.type_collector import TypeCollector + + +class CollectionGraphTypeFactory: + @classmethod + def get(cls, node_type: Type[StrawberryProtocol]) -> Type: + type_name = f"{node_type.__name__.replace('GraphType', '')}Collection" + + if TypeCollector.has(type_name): + return TypeCollector.get(type_name) + + node_t = get_provider().get_service(node_type) + if not node_t: + raise ValueError(f"Node type '{node_type.__name__}' not registered in service provider") + + gql_node = node_t.to_strawberry() if hasattr(node_type, "to_strawberry") else node_type + + gql_cls = type(type_name, (), {}) + + TypeCollector.set(type_name, gql_cls) + + gql_cls.__annotations__ = { + "nodes": List[gql_node], + "total_count": int, + "count": int, + } + for k in gql_cls.__annotations__.keys(): + setattr(gql_cls, k, strawberry.field()) + + gql_type = strawberry.type(gql_cls) + + TypeCollector.set(type_name, gql_type) + return gql_type + + +class Collection: + def __init__(self, nodes: list[T], total_count: int, count: int): + self._nodes = nodes + self._total_count = total_count + self._count = count + + @property + def nodes(self) -> list[T]: + return self._nodes + + @property + def total_count(self) -> int: + return self._total_count + + @property + def count(self) -> int: + return self._count diff --git a/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py new file mode 100644 index 00000000..ed4153a2 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/db_model_graph_type.py @@ -0,0 +1,62 @@ +from typing import Type, Optional, Generic, Annotated + +import strawberry + +from cpl.core.configuration import Configuration +from cpl.core.typing import T +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.query import Query + + +class DbModelGraphType(GraphType[T], Generic[T]): + + def __init__(self, t_dao: Type[DataAccessObjectABC] = None, with_history: bool = False, public: bool = False): + Query.__init__(self) + + self._dao: Optional[DataAccessObjectABC] = None + + if t_dao is not None: + dao = self._provider.get_service(t_dao) + if dao is not None: + self._dao = dao + + self.int_field("id", lambda root: root.id).with_public(public) + self.bool_field("deleted", lambda root: root.deleted).with_public(public) + + if Configuration.get("GraphQLAuthModuleEnabled", False): + from cpl.graphql.auth.user.user_graph_type import UserGraphType + + self.object_field("editor", lambda: UserGraphType, lambda root: root.editor).with_public(public) + + self.string_field("created", lambda root: root.created).with_public(public) + self.string_field("updated", lambda root: root.updated).with_public(public) + + # if with_history: + # if self._dao is None: + # raise ValueError("DAO must be provided to enable history") + # self.set_field("history", self._resolve_history).with_public(public) + + self._history_reference_daos: dict[DataAccessObjectABC, str] = {} + + async def _resolve_history(self, root): + if self._dao is None: + raise Exception("DAO not set for history query") + + history = sorted( + [await self._dao.get_by_id(root.id), *await self._dao.get_history(root.id)], + key=lambda h: h.updated, + reverse=True, + ) + return history + + def set_history_reference_dao(self, dao: DataAccessObjectABC, key: str = None): + """ + Set the reference DAO for history resolution. + :param dao: + :param key: The key to use for resolving history. + :return: + """ + if key is None: + key = "id" + self._history_reference_daos[dao] = key diff --git a/src/cpl-graphql/cpl/graphql/schema/field.py b/src/cpl-graphql/cpl/graphql/schema/field.py new file mode 100644 index 00000000..7866fafa --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/field.py @@ -0,0 +1,141 @@ +from enum import Enum +from typing import Self + +from cpl.graphql.schema.argument import Argument +from cpl.graphql.typing import TQuery, Resolver, TRequireAnyPermissions, TRequireAnyResolvers + + +class Field: + + def __init__( + self, + name: str, + t: type = None, + resolver: Resolver = None, + optional=None, + default=None, + subquery: TQuery = None, + parent_type=None, + ): + self._name = name + self._type = t + self._resolver = resolver + self._optional = optional or True + self._default = default + + self._subquery = subquery + self._parent_type = parent_type + + self._args: dict[str, Argument] = {} + self._require_any_permission = None + self._require_any = None + self._public = False + + @property + def name(self) -> str: + return self._name + + @property + def type(self) -> type: + return self._type + + @property + def resolver(self) -> callable: + return self._resolver + + @property + def optional(self) -> bool | None: + return self._optional + + @property + def default(self): + return self._default + + @property + def args(self) -> dict: + return self._args + + @property + def subquery(self) -> TQuery | None: + return self._subquery + + @property + def parent_type(self): + return self._parent_type + + @property + def arguments(self) -> dict[str, Argument]: + return self._args + + @property + def require_any_permission(self) -> TRequireAnyPermissions | None: + return self._require_any_permission + + @property + def require_any(self) -> TRequireAnyResolvers | None: + return self._require_any + + @property + def public(self) -> bool: + return self._public + + def with_type(self, t: type) -> Self: + self._type = t + return self + + def with_resolver(self, resolver: Resolver) -> Self: + self._resolver = resolver + return self + + def with_optional(self, optional: bool = True) -> Self: + self._optional = optional + return self + + def with_required(self, required: bool = True) -> Self: + self._optional = not required + return self + + def with_default(self, default) -> Self: + self._default = default + return self + + def with_argument( + self, name: str, arg_type: type, description: str = None, default_value=None, optional=True + ) -> Argument: + if name in self._args: + raise ValueError(f"Argument with name '{name}' already exists in field '{self._name}'") + self._args[name] = Argument(name, arg_type, description, default_value, optional) + return self._args[name] + + def with_arguments(self, args: list[Argument]) -> Self: + for arg in args: + if not isinstance(arg, Argument): + raise ValueError(f"Expected Argument instance, got {type(arg)}") + + self.with_argument(arg.type, arg.name, arg.description, arg.default, arg.optional) + return self + + def with_require_any_permission(self, *permissions: TRequireAnyPermissions) -> Self: + if not isinstance(permissions, list): + permissions = list(permissions) + + assert permissions is not None, "require_any_permission cannot be None" + assert all(isinstance(x, (str, Enum)) for x in permissions), "All permissions must be of Permission type" + self._require_any_permission = permissions + return self + + def with_require_any(self, permissions: TRequireAnyPermissions, resolvers: TRequireAnyResolvers) -> Self: + assert permissions is not None, "permissions cannot be None" + assert all(isinstance(p, (str, Enum)) for p in permissions), "All permissions must be of Permission type" + assert resolvers is not None, "resolvers cannot be None" + assert all(callable(r) for r in resolvers), "All resolvers must be callable" + self._require_any = (permissions, resolvers) + return self + + def with_public(self, public: bool = True) -> Self: + if public: + self._require_any = None + self._require_any_permission = None + + self._public = public + return self diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/__init__.py b/src/cpl-graphql/cpl/graphql/schema/filter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py new file mode 100644 index 00000000..4be0db85 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/bool_filter.py @@ -0,0 +1,10 @@ +from cpl.graphql.schema.input import Input + + +class BoolFilter(Input[bool]): + def __init__(self): + super().__init__() + self.field("equal", bool, optional=True) + self.field("notEqual", bool, optional=True) + self.field("isNull", bool, optional=True) + self.field("isNotNull", bool, optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py new file mode 100644 index 00000000..0149a3b9 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/date_filter.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from cpl.graphql.schema.input import Input + + +class DateFilter(Input[datetime]): + def __init__(self): + super().__init__() + self.field("equal", datetime, optional=True) + self.field("notEqual", datetime, optional=True) + self.field("greater", datetime, optional=True) + self.field("greaterOrEqual", datetime, optional=True) + self.field("less", datetime, optional=True) + self.field("lessOrEqual", datetime, optional=True) + self.field("isNull", datetime, optional=True) + self.field("isNotNull", datetime, optional=True) + self.field("in", list[datetime], optional=True) + self.field("notIn", list[datetime], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py new file mode 100644 index 00000000..4a91544c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/db_model_filter.py @@ -0,0 +1,23 @@ +from typing import Generic + +from cpl.core.configuration.configuration import Configuration +from cpl.core.typing import T +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.filter import Filter +from cpl.graphql.schema.filter.int_filter import IntFilter + + +class DbModelFilter(Filter[T], Generic[T]): + def __init__(self, public: bool = False): + Filter.__init__(self) + + self.field("id", IntFilter).with_public(public) + self.field("deleted", BoolFilter).with_public(public) + if Configuration.get("GraphQLAuthModuleEnabled", False): + from cpl.graphql.auth.user.user_filter import UserFilter + + self.field("editor", lambda: UserFilter).with_public(public) + + self.field("created", DateFilter).with_public(public) + self.field("updated", DateFilter).with_public(public) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py new file mode 100644 index 00000000..75bd3c3c --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/filter.py @@ -0,0 +1,28 @@ +from typing import Type + +from cpl.core.typing import T +from cpl.graphql.schema.filter.bool_filter import BoolFilter +from cpl.graphql.schema.filter.date_filter import DateFilter +from cpl.graphql.schema.filter.int_filter import IntFilter +from cpl.graphql.schema.filter.string_filter import StringFilter +from cpl.graphql.schema.input import Input + + +class Filter(Input[T]): + def __init__(self): + Input.__init__(self) + + def filter_field(self, name: str, filter_type: Type["Filter"]): + self.field(name, filter_type) + + def string_field(self, name: str): + self.field(name, StringFilter) + + def int_field(self, name: str): + self.field(name, IntFilter) + + def bool_field(self, name: str): + self.field(name, BoolFilter) + + def date_field(self, name: str): + self.field(name, DateFilter) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py new file mode 100644 index 00000000..801ad562 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/int_filter.py @@ -0,0 +1,16 @@ +from cpl.graphql.schema.input import Input + + +class IntFilter(Input[int]): + def __init__(self): + super().__init__() + self.field("equal", int, optional=True) + self.field("notEqual", int, optional=True) + self.field("greater", int, optional=True) + self.field("greaterOrEqual", int, optional=True) + self.field("less", int, optional=True) + self.field("lessOrEqual", int, optional=True) + self.field("isNull", int, optional=True) + self.field("isNotNull", int, optional=True) + self.field("in", list[int], optional=True) + self.field("notIn", list[int], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py b/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py new file mode 100644 index 00000000..7c060abc --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/filter/string_filter.py @@ -0,0 +1,16 @@ +from cpl.graphql.schema.input import Input + + +class StringFilter(Input[str]): + def __init__(self): + super().__init__() + self.field("equal", str, optional=True) + self.field("notEqual", str, optional=True) + self.field("contains", str, optional=True) + self.field("notContains", str, optional=True) + self.field("startsWith", str, optional=True) + self.field("endsWith", str, optional=True) + self.field("isNull", str, optional=True) + self.field("isNotNull", str, optional=True) + self.field("in", list[str], optional=True) + self.field("notIn", list[str], optional=True) diff --git a/src/cpl-graphql/cpl/graphql/schema/graph_type.py b/src/cpl-graphql/cpl/graphql/schema/graph_type.py new file mode 100644 index 00000000..b4d5b422 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/graph_type.py @@ -0,0 +1,10 @@ +from typing import Generic + +from cpl.core.typing import T +from cpl.graphql.schema.query import Query + + +class GraphType(Query, Generic[T]): + + def __init__(self): + Query.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/input.py b/src/cpl-graphql/cpl/graphql/schema/input.py new file mode 100644 index 00000000..ce7817ab --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/input.py @@ -0,0 +1,115 @@ +import types +from typing import Generic, Dict, Type, Optional, Union, Any + +import strawberry + +from cpl.core.typing import T +from cpl.dependency import get_provider +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.field import Field +from cpl.graphql.typing import AttributeName +from cpl.graphql.utils.type_collector import TypeCollector + +_PYTHON_KEYWORDS = {"in", "not", "is", "and", "or"} + + +class Input(StrawberryProtocol, Generic[T]): + def __init__(self): + self._fields: Dict[str, Field] = {} + self._values: Dict[str, Any] = {} + + @property + def fields(self) -> Dict[str, Field]: + return self._fields + + def __getattr__(self, item): + if item in self._values: + return self._values[item] + raise AttributeError(f"{self.__class__.__name__} has no attribute {item}") + + def __setattr__(self, key, value): + if key in {"_fields", "_values"}: + super().__setattr__(key, value) + elif key in self._fields: + self._values[key] = value + else: + super().__setattr__(key, value) + + def get(self, key: str, default=None): + return self._values.get(key, default) + + def get_fields(self) -> dict[str, Field]: + return self._fields + + def field(self, name: AttributeName, typ: type, optional: bool = True) -> Field: + if isinstance(name, property): + name = name.fget.__name__ + + self._fields[name] = Field(name, typ, optional=optional) + return self._fields[name] + + def string_field(self, name: AttributeName, optional: bool = True) -> Field: + return self.field(name, str) + + def int_field(self, name: AttributeName, optional: bool = True) -> Field: + return self.field(name, int, optional) + + def float_field(self, name: AttributeName, optional: bool = True) -> Field: + return self.field(name, float, optional) + + def bool_field(self, name: AttributeName, optional: bool = True) -> Field: + return self.field(name, bool, optional) + + def list_field(self, name: AttributeName, t: type, optional: bool = True) -> Field: + return self.field(name, list[t], optional) + + def object_field(self, name: AttributeName, t: Type[StrawberryProtocol], optional: bool = True) -> Field: + if not isinstance(t, type) and callable(t): + return self.field(name, t, optional) + + return self.field(name, t().to_strawberry(), optional) + + def to_strawberry(self) -> Type: + cls = self.__class__ + if TypeCollector.has(cls): + return TypeCollector.get(cls) + + gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {}) + # register early to handle recursive types + TypeCollector.set(cls, gql_cls) + + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {} + + for name, f in self._fields.items(): + t = f.type + + if isinstance(t, types.FunctionType): + _t = get_provider().get_service(t()) + if _t is None: + raise ValueError(f"'{t()}' could not be resolved from the provider") + t = _t.to_strawberry() + elif isinstance(t, type) and issubclass(t, Input): + t = t().to_strawberry() + elif isinstance(t, Input): + t = t.to_strawberry() + + py_name = name + "_" if name in _PYTHON_KEYWORDS else name + annotations[py_name] = t if not f.optional else Optional[t] + + field_args = {} + if py_name != name: + field_args["name"] = name + + default = None if f.optional else f.default + namespace[py_name] = strawberry.field(default=default, **field_args) + + namespace["__annotations__"] = annotations + + for k, v in namespace.items(): + setattr(gql_cls, k, v) + + gql_cls.__annotations__ = annotations + gql_type = strawberry.input(gql_cls) + TypeCollector.set(cls, gql_type) + return gql_type diff --git a/src/cpl-graphql/cpl/graphql/schema/mutation.py b/src/cpl-graphql/cpl/graphql/schema/mutation.py new file mode 100644 index 00000000..d336f3b1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/mutation.py @@ -0,0 +1,93 @@ +from typing import Type, Union + +from cpl.core.typing import T +from cpl.database.abc import DataAccessObjectABC, DbJoinModelABC +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.query_abc import QueryABC +from cpl.graphql.schema.field import Field + + +class Mutation(QueryABC): + + @inject + def __init__(self, provider: ServiceProvider): + QueryABC.__init__(self) + self._provider = provider + + from cpl.graphql.service.schema import Schema + + self._schema = provider.get_service(Schema) + + def with_mutation(self, name: str, cls: Type["Mutation"]) -> Field: + sub = self._provider.get_service(cls) + if not sub: + raise ValueError(f"Mutation '{cls.__name__}' not registered in service provider") + + return self.field(name, sub.to_strawberry(), lambda: sub) + + @staticmethod + async def _resolve_assignments( + foreign_objs: list[int], + resolved_obj: T, + reference_key_own: Union[str, property], + reference_key_foreign: Union[str, property], + source_dao: DataAccessObjectABC[T], + join_dao: DataAccessObjectABC[T], + join_type: Type[DbJoinModelABC], + foreign_dao: DataAccessObjectABC[T], + ): + if foreign_objs is None: + return + + reference_key_foreign_attr = reference_key_foreign + if isinstance(reference_key_foreign, property): + reference_key_foreign_attr = reference_key_foreign.fget.__name__ + + foreign_list = await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": False}]) + + to_delete = ( + foreign_list + if len(foreign_objs) == 0 + else await join_dao.find_by( + [ + {reference_key_own: resolved_obj.id}, + {reference_key_foreign: {"notIn": foreign_objs}}, + ] + ) + ) + foreign_ids = [getattr(x, reference_key_foreign_attr) for x in foreign_list] + deleted_foreign_ids = [ + getattr(x, reference_key_foreign_attr) + for x in await join_dao.find_by([{reference_key_own: resolved_obj.id}, {"deleted": True}]) + ] + + to_create = [ + join_type(0, resolved_obj.id, x) + for x in foreign_objs + if x not in foreign_ids and x not in deleted_foreign_ids + ] + to_restore = [ + await join_dao.get_single_by( + [ + {reference_key_own: resolved_obj.id}, + {reference_key_foreign: x}, + ] + ) + for x in foreign_objs + if x not in foreign_ids and x in deleted_foreign_ids + ] + + if len(to_delete) > 0: + await join_dao.delete_many(to_delete) + + if len(to_create) > 0: + await join_dao.create_many(to_create) + + if len(to_restore) > 0: + await join_dao.restore_many(to_restore) + + foreign_changes = [*to_delete, *to_create, *to_restore] + if len(foreign_changes) > 0: + await source_dao.touch(resolved_obj) + await foreign_dao.touch_many_by_id([getattr(x, reference_key_foreign_attr) for x in foreign_changes]) diff --git a/src/cpl-graphql/cpl/graphql/schema/query.py b/src/cpl-graphql/cpl/graphql/schema/query.py new file mode 100644 index 00000000..cbd05781 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/query.py @@ -0,0 +1,131 @@ +from typing import Callable, Type + +from cpl.database.abc.data_access_object_abc import DataAccessObjectABC +from cpl.dependency.inject import inject +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.query_abc import QueryABC +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.collection import Collection, CollectionGraphTypeFactory +from cpl.graphql.schema.field import Field +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class Query(QueryABC): + + @inject + def __init__(self, provider: ServiceProvider): + QueryABC.__init__(self) + self._provider = provider + + from cpl.graphql.service.schema import Schema + + self._schema = provider.get_service(Schema) + + def with_query(self, name: str, subquery_cls: Type["Query"]) -> Field: + sub = self._provider.get_service(subquery_cls) + if not sub: + raise ValueError(f"Subquery '{subquery_cls.__name__}' not registered in service provider") + + return self.field(name, sub.to_strawberry(), lambda: sub) + + def collection_field( + self, + t: type, + name: str, + filter_type: Type[StrawberryProtocol], + sort_type: Type[StrawberryProtocol], + resolver: Callable, + ) -> Field: + def _resolve_collection(filter=None, sort=None, skip=0, take=10): + items = resolver() + if filter: + for field, value in filter.__dict__.items(): + if value is None: + continue + items = [i for i in items if getattr(i, field) == value] + + if sort: + for field, direction in sort.__dict__.items(): + reverse = direction == SortOrder.DESC + items = sorted(items, key=lambda i: getattr(i, field), reverse=reverse) + total_count = len(items) + paged = items[skip : skip + take] + return Collection(nodes=paged, total_count=total_count, count=len(paged)) + + filter = self._provider.get_service(filter_type) + if not filter: + raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider") + + sort = self._provider.get_service(sort_type) + if not sort: + raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + + f = self.field(name, CollectionGraphTypeFactory.get(t), _resolve_collection) + f.with_argument("filter", filter.to_strawberry()) + f.with_argument("sort", sort.to_strawberry()) + f.with_argument("skip", int, default_value=0) + f.with_argument("take", int, default_value=10) + return f + + def dao_collection_field( + self, + t: Type[StrawberryProtocol], + dao_type: Type[DataAccessObjectABC], + name: str, + filter_type: Type[StrawberryProtocol], + sort_type: Type[StrawberryProtocol], + ) -> Field: + assert issubclass(dao_type, DataAccessObjectABC), "dao_type must be a subclass of DataAccessObjectABC" + dao = self._provider.get_service(dao_type) + if not dao: + raise ValueError(f"DAO '{dao_type.__name__}' not registered in service provider") + + filter = self._provider.get_service(filter_type) + if not filter: + raise ValueError(f"Filter '{filter_type.__name__}' not registered in service provider") + + sort = self._provider.get_service(sort_type) + if not sort: + raise ValueError(f"Sort '{sort_type.__name__}' not registered in service provider") + + def input_to_dict(obj) -> dict | None: + if obj is None: + return None + + result = {} + for k, v in obj.__dict__.items(): + if v is None: + continue + + if hasattr(v, "__dict__"): + result[k] = input_to_dict(v) + else: + result[k] = v + return result + + async def _resolver(filter=None, sort=None, take=10, skip=0): + filter_dict = input_to_dict(filter) if filter is not None else None + sort_dict = None + + if sort is not None: + sort_dict = {} + for k, v in sort.__dict__.items(): + if v is None: + continue + + if isinstance(v, SortOrder): + sort_dict[k] = str(v.value).lower() + continue + + sort_dict[k] = str(v).lower() + + total_count = await dao.count(filter_dict) + data = await dao.find_by(filter_dict, sort_dict, take, skip) + return Collection(nodes=data, total_count=total_count, count=len(data)) + + f = self.field(name, CollectionGraphTypeFactory.get(t), _resolver) + f.with_argument("filter", filter.to_strawberry()) + f.with_argument("sort", sort.to_strawberry()) + f.with_argument("skip", int, default_value=0) + f.with_argument("take", int, default_value=10) + return f diff --git a/src/cpl-graphql/cpl/graphql/schema/root_mutation.py b/src/cpl-graphql/cpl/graphql/schema/root_mutation.py new file mode 100644 index 00000000..8855d8e7 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_mutation.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.mutation import Mutation + + +class RootMutation(Mutation): + def __init__(self): + Mutation.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/root_query.py b/src/cpl-graphql/cpl/graphql/schema/root_query.py new file mode 100644 index 00000000..85ee1d38 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_query.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.query import Query + + +class RootQuery(Query): + def __init__(self): + Query.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/root_subscription.py b/src/cpl-graphql/cpl/graphql/schema/root_subscription.py new file mode 100644 index 00000000..fab2bc8f --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/root_subscription.py @@ -0,0 +1,6 @@ +from cpl.graphql.schema.subscription import Subscription + + +class RootSubscription(Subscription): + def __init__(self): + Subscription.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/__init__.py b/src/cpl-graphql/cpl/graphql/schema/sort/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py b/src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py new file mode 100644 index 00000000..02726ec8 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/sort/db_model_sort.py @@ -0,0 +1,19 @@ +from typing import Generic + +from cpl.core.configuration import Configuration +from cpl.core.typing import T +from cpl.graphql.schema.sort.sort import Sort +from cpl.graphql.schema.sort.sort_order import SortOrder + + +class DbModelSort(Sort[T], Generic[T]): + def __init__( + self, + ): + Sort.__init__(self) + self.field("id", SortOrder) + self.field("deleted", SortOrder) + if Configuration.get("GraphQLAuthModuleEnabled", False): + self.field("editor", SortOrder) + self.field("created", SortOrder) + self.field("updated", SortOrder) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/sort.py b/src/cpl-graphql/cpl/graphql/schema/sort/sort.py new file mode 100644 index 00000000..ccbb6980 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/sort/sort.py @@ -0,0 +1,9 @@ +from cpl.core.typing import T +from cpl.graphql.schema.input import Input + + +class Sort(Input[T]): + def __init__( + self, + ): + Input.__init__(self) diff --git a/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py new file mode 100644 index 00000000..db75e06e --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/sort/sort_order.py @@ -0,0 +1,6 @@ +from enum import Enum, auto + + +class SortOrder(Enum): + ASC = "ASC" + DESC = "DESC" diff --git a/src/cpl-graphql/cpl/graphql/schema/subscription.py b/src/cpl-graphql/cpl/graphql/schema/subscription.py new file mode 100644 index 00000000..1be59d84 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/subscription.py @@ -0,0 +1,88 @@ +import inspect +from typing import Any, Type, Optional, Self + +import strawberry +from strawberry.exceptions import StrawberryException + +from cpl.api import Unauthorized, Forbidden +from cpl.core.ctx.user_context import get_user +from cpl.dependency import get_provider, inject +from cpl.dependency.event_bus import EventBusABC +from cpl.graphql.abc.query_abc import QueryABC +from cpl.graphql.error import graphql_error +from cpl.graphql.query_context import QueryContext +from cpl.graphql.schema.subscription_field import SubscriptionField +from cpl.graphql.typing import Selector + + +class Subscription(QueryABC): + + @inject + def __init__(self, bus: EventBusABC): + QueryABC.__init__(self) + self._bus = bus + + def subscription_field( + self, + name: str, + t: Type, + selector: Optional[Selector] = None, + channel: Optional[str] = None, + ) -> SubscriptionField: + field = SubscriptionField(name, t, selector, channel) + self._fields[name] = field + return field + + def with_subscription(self, sub_cls: Type[Self]) -> Self: + sub = get_provider().get_service(sub_cls) + if not sub: + raise ValueError(f"Subscription '{sub_cls.__name__}' not registered in provider") + + for sub_name, sub_field in sub.get_fields().items(): + self._fields[sub_name] = sub_field + + return self + + def _field_to_strawberry(self, f: SubscriptionField) -> Any: + try: + if isinstance(f, SubscriptionField): + + def make_resolver(field: SubscriptionField): + async def resolver(root=None, info=None): + if not field.public: + user = get_user() + if not user: + raise graphql_error(Unauthorized(f"{field.name}: Authentication required")) + + if field.require_any_permission: + ok = any([await user.has_permission(p) for p in field.require_any_permission]) + if not ok: + raise graphql_error(Forbidden(f"{field.name}: Permission denied")) + + if field.require_any: + perms, resolvers = field.require_any + ok = any([await user.has_permission(p) for p in perms]) + if not ok: + ctx = QueryContext([x.name for x in await user.permissions]) + results = [ + r(ctx) if not inspect.iscoroutinefunction(r) else await r(ctx) + for r in resolvers + ] + if not any(results): + raise graphql_error(Forbidden(f"{field.name}: Permission denied")) + + async for event in self._bus.subscribe(field.channel): + if field.selector is None or field.selector(event, info): + yield event + + return resolver + + return strawberry.subscription(resolver=make_resolver(f)) + + async def wrapper_resolver(root=None, info=None): + yield None + + return strawberry.subscription(resolver=wrapper_resolver) + + except StrawberryException as e: + raise Exception(f"Error converting subscription field '{f.name}': {e}") from e diff --git a/src/cpl-graphql/cpl/graphql/schema/subscription_field.py b/src/cpl-graphql/cpl/graphql/schema/subscription_field.py new file mode 100644 index 00000000..bab90a70 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/schema/subscription_field.py @@ -0,0 +1,25 @@ +from typing import Type, Callable, Optional + +from cpl.graphql.schema.field import Field +from cpl.graphql.typing import Selector + + +class SubscriptionField(Field): + def __init__( + self, + name: str, + t: Type, + selector: Optional[Selector] = None, + channel: Optional[str] = None, + ): + super().__init__(name, t) + self.selector = selector + self.channel = channel or name + + def with_selector(self, selector: Selector) -> "SubscriptionField": + self.selector = selector + return self + + def with_channel(self, channel: str) -> "SubscriptionField": + self.channel = channel + return self diff --git a/src/cpl-graphql/cpl/graphql/service/__init__.py b/src/cpl-graphql/cpl/graphql/service/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/service/graphql.py b/src/cpl-graphql/cpl/graphql/service/graphql.py new file mode 100644 index 00000000..7262906d --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/graphql.py @@ -0,0 +1,52 @@ +from typing import Any, Dict, Optional + +from graphql import GraphQLError + +from cpl.api import APILogger, APIError +from cpl.api.typing import TRequest +from cpl.graphql.service.schema import Schema + + +class GraphQLService: + def __init__(self, logger: APILogger, schema: Schema): + self._logger = logger + + if schema.schema is None: + raise ValueError("Schema has not been built. Call schema.build() before using the service.") + self._schema = schema.schema + + async def execute( + self, + query: str, + variables: Optional[Dict[str, Any]], + request: TRequest, + ) -> Dict[str, Any]: + result = await self._schema.execute( + query, + variable_values=variables, + context_value={"request": request}, + ) + + response_data: Dict[str, Any] = {} + if result.errors: + errors = [] + for error in result.errors: + if isinstance(error, APIError): + self._logger.error(f"GraphQL APIError", error) + errors.append({"message": error.error_message, "extensions": {"code": error.status_code}}) + continue + + if isinstance(error, GraphQLError): + + self._logger.error(f"GraphQLError", error) + errors.append({"message": error.message, "extensions": error.extensions}) + continue + + self._logger.error(f"GraphQL unexpected error", error) + errors.append({"message": str(error), "extensions": {"code": 500}}) + + response_data["errors"] = errors + if result.data: + response_data["data"] = result.data + + return response_data diff --git a/src/cpl-graphql/cpl/graphql/service/schema.py b/src/cpl-graphql/cpl/graphql/service/schema.py new file mode 100644 index 00000000..d56428b0 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/service/schema.py @@ -0,0 +1,76 @@ +import logging +from typing import Type, Self + +import strawberry + +from cpl.api.logger import APILogger +from cpl.dependency.service_provider import ServiceProvider +from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol +from cpl.graphql.schema.root_mutation import RootMutation +from cpl.graphql.schema.root_query import RootQuery +from cpl.graphql.schema.root_subscription import RootSubscription + + +class Schema: + + def __init__(self, logger: APILogger, provider: ServiceProvider): + self._logger = logger + self._provider = provider + + self._types: dict[str, Type[StrawberryProtocol]] = {} + + self._schema = None + + @property + def schema(self) -> strawberry.Schema | None: + return self._schema + + @property + def query(self) -> RootQuery: + query = self._provider.get_service(RootQuery) + if not query: + raise ValueError("RootQuery not registered in service provider") + return query + + @property + def mutation(self) -> RootMutation: + mutation = self._provider.get_service(RootMutation) + if not mutation: + raise ValueError("RootMutation not registered in service provider") + return mutation + + @property + def subscription(self) -> RootSubscription: + subscription = self._provider.get_service(RootSubscription) + if not subscription: + raise ValueError("RootSubscription not registered in service provider") + return subscription + + def with_type(self, t: Type[StrawberryProtocol]) -> Self: + self._types[t.__name__] = t + return self + + def _get_types(self): + types: list[Type] = [] + for t in self._types.values(): + t_obj = self._provider.get_service(t) + if not t_obj: + raise ValueError(f"Type '{t.__name__}' not registered in service provider") + types.append(t_obj.to_strawberry()) + + return types + + def build(self) -> strawberry.Schema: + logging.getLogger("strawberry.execution").setLevel(logging.CRITICAL) + + query = self.query + mutation = self.mutation + subscription = self.subscription + + self._schema = strawberry.Schema( + query=query.to_strawberry() if query.fields_count > 0 else None, + mutation=mutation.to_strawberry() if mutation.fields_count > 0 else None, + subscription=subscription.to_strawberry() if subscription.fields_count > 0 else None, + types=self._get_types(), + ) + return self._schema diff --git a/src/cpl-graphql/cpl/graphql/typing.py b/src/cpl-graphql/cpl/graphql/typing.py new file mode 100644 index 00000000..bb8cda8e --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/typing.py @@ -0,0 +1,16 @@ +from enum import Enum +from typing import Type, Callable, List, Tuple, Awaitable, Any + +import strawberry + +from cpl.auth.permission import Permissions +from cpl.graphql.query_context import QueryContext + +TQuery = Type["Query"] +Resolver = Callable +Selector = Callable[[Any, strawberry.types.Info], bool] +ScalarType = str | int | float | bool | object +AttributeName = str | property +TRequireAnyPermissions = List[Enum | Permissions] | None +TRequireAnyResolvers = List[Callable[[QueryContext], bool | Awaitable[bool]],] +TRequireAny = Tuple[TRequireAnyPermissions, TRequireAnyResolvers] diff --git a/src/cpl-graphql/cpl/graphql/utils/__init__.py b/src/cpl-graphql/cpl/graphql/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl-graphql/cpl/graphql/utils/name_pipe.py b/src/cpl-graphql/cpl/graphql/utils/name_pipe.py new file mode 100644 index 00000000..7e9b72b1 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/utils/name_pipe.py @@ -0,0 +1,28 @@ +from cpl.core.pipes import PipeABC +from cpl.core.typing import T +from cpl.graphql.schema.collection import CollectionGraphType +from cpl.graphql.schema.graph_type import GraphType +from cpl.graphql.schema.object_graph_type import ObjectGraphType + + +class NamePipe(PipeABC): + + @staticmethod + def to_str(value: type, *args) -> str: + if isinstance(value, str): + return value + + if not isinstance(value, type): + raise ValueError(f"Expected a type, got {type(value)}") + + if issubclass(value, CollectionGraphType): + return f"{value.__name__.replace(GraphType.__name__, "")}" + + if issubclass(value, (ObjectGraphType, GraphType)): + return value.__name__.replace(GraphType.__name__, "") + + return value.__name__ + + @staticmethod + def from_str(value: str, *args) -> T: + pass diff --git a/src/cpl-graphql/cpl/graphql/utils/type_collector.py b/src/cpl-graphql/cpl/graphql/utils/type_collector.py new file mode 100644 index 00000000..439d3ec2 --- /dev/null +++ b/src/cpl-graphql/cpl/graphql/utils/type_collector.py @@ -0,0 +1,17 @@ +from typing import Type, Any + + +class TypeCollector: + _registry: dict[type | str, Type] = {} + + @classmethod + def has(cls, base: type | str) -> bool: + return base in cls._registry + + @classmethod + def get(cls, base: type | str) -> Type: + return cls._registry[base] + + @classmethod + def set(cls, base: type | str, gql_type: Type): + cls._registry[base] = gql_type diff --git a/src/cpl-graphql/pyproject.toml b/src/cpl-graphql/pyproject.toml new file mode 100644 index 00000000..cecb85d2 --- /dev/null +++ b/src/cpl-graphql/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["setuptools>=70.1.0", "wheel>=0.43.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "cpl-database" +version = "2024.7.0" +description = "CPL database" +readme ="CPL database package" +requires-python = ">=3.12" +license = { text = "MIT" } +authors = [ + { name = "Sven Heidemann", email = "sven.heidemann@sh-edraft.de" } +] +keywords = ["cpl", "database", "backend", "shared", "library"] + +dynamic = ["dependencies", "optional-dependencies"] + +[project.urls] +Homepage = "https://www.sh-edraft.de" + +[tool.setuptools.packages.find] +where = ["."] +include = ["cpl*"] + +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } +optional-dependencies.dev = { file = ["requirements.dev.txt"] } + + diff --git a/src/cpl-graphql/requirements.dev.txt b/src/cpl-graphql/requirements.dev.txt new file mode 100644 index 00000000..e7664b42 --- /dev/null +++ b/src/cpl-graphql/requirements.dev.txt @@ -0,0 +1 @@ +black==25.1.0 \ No newline at end of file diff --git a/src/cpl-graphql/requirements.txt b/src/cpl-graphql/requirements.txt new file mode 100644 index 00000000..d74de843 --- /dev/null +++ b/src/cpl-graphql/requirements.txt @@ -0,0 +1,2 @@ +cpl-api +strawberry-graphql==0.282.0 \ No newline at end of file diff --git a/src/cpl-query/cpl/query/ordered_enumerable.py b/src/cpl-query/cpl/query/ordered_enumerable.py index 89edc3d7..03405057 100644 --- a/src/cpl-query/cpl/query/ordered_enumerable.py +++ b/src/cpl-query/cpl/query/ordered_enumerable.py @@ -6,7 +6,7 @@ from cpl.query.typing import K class OrderedEnumerable(Enumerable[T]): def __init__(self, source, key_selectors: List[tuple[Callable[[T], K], bool]]): - super().__init__(source) + Enumerable.__init__(self, source) self._key_selectors = key_selectors def __iter__(self) -> Iterator[T]: