Merge pull request 'Added gql base #181' (#196) from #181_gql into dev
All checks were successful
Test before pr merge / test-lint (pull_request) Successful in 7s
Build on push / prepare (push) Successful in 12s
Build on push / query (push) Successful in 23s
Build on push / core (push) Successful in 23s
Build on push / dependency (push) Successful in 29s
Build on push / translation (push) Successful in 17s
Build on push / application (push) Successful in 22s
Build on push / database (push) Successful in 22s
Build on push / mail (push) Successful in 23s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 15s

Reviewed-on: #196
Closes #181
This commit is contained in:
2025-10-08 21:25:41 +02:00
137 changed files with 3481 additions and 270 deletions

View File

@@ -1,40 +1,80 @@
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.event_bus.memory import InMemoryEventBus
from queries.cities import CityGraphType, CityFilter, CitySort
from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType
from queries.user import UserFilter, UserSort
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.api.application.web_app import WebApp
from cpl.application.application_builder import ApplicationBuilder from cpl.application.application_builder import ApplicationBuilder
from cpl.auth import AuthModule from cpl.auth.schema import User, Role
from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema import AuthUser, Role
from cpl.core.configuration import Configuration from cpl.core.configuration import Configuration
from cpl.core.console import Console from cpl.core.console import Console
from cpl.core.environment import Environment from cpl.core.environment import Environment
from cpl.core.utils.cache import Cache from cpl.core.utils.cache import Cache
from cpl.database.mysql.mysql_module import MySQLModule from cpl.database.mysql.mysql_module import MySQLModule
from cpl.graphql.application.graphql_app import GraphQLApp
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
from cpl.graphql.graphql_module import GraphQLModule
from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
from model.post_dao import PostDao
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription
from permissions import PostPermissions
from queries.hello import HelloQuery
from scoped_service import ScopedService from scoped_service import ScopedService
from service import PingService from service import PingService
from test_data_seeder import TestDataSeeder
def main(): def main():
builder = ApplicationBuilder[WebApp](WebApp) builder = ApplicationBuilder[GraphQLApp](GraphQLApp)
Configuration.add_json_file(f"appsettings.json") Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json") Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True) Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
# builder.services.add_logging() # builder.services.add_logging()
builder.services.add_structured_logging() (
builder.services.add_transient(PingService) builder.services.add_structured_logging()
builder.services.add_module(MySQLModule) .add_transient(PingService)
builder.services.add_module(ApiModule) .add_module(MySQLModule)
.add_module(ApiModule)
builder.services.add_scoped(ScopedService) .add_module(GraphQLModule)
.add_module(GraphQLAuthModule)
builder.services.add_cache(AuthUser) .add_scoped(ScopedService)
builder.services.add_cache(Role) .add_singleton(EventBusABC, InMemoryEventBus)
.add_cache(User)
.add_cache(Role)
.add_transient(CityGraphType)
.add_transient(CityFilter)
.add_transient(CitySort)
.add_transient(UserGraphType)
.add_transient(UserFilter)
.add_transient(UserSort)
# .add_transient(UserGraphType)
# .add_transient(UserFilter)
# .add_transient(UserSort)
.add_transient(HelloQuery)
# test data
.add_singleton(TestDataSeeder)
# authors
.add_transient(AuthorDao)
.add_transient(AuthorGraphType)
.add_transient(AuthorFilter)
.add_transient(AuthorSort)
# posts
.add_transient(PostDao)
.add_transient(PostGraphType)
.add_transient(PostFilter)
.add_transient(PostSort)
.add_transient(PostMutation)
.add_transient(PostSubscription)
)
app = builder.build() app = builder.build()
app.with_logging() app.with_logging()
app.with_migrations("./scripts")
app.with_authentication() app.with_authentication()
app.with_authorization() app.with_authorization()
@@ -43,13 +83,35 @@ def main():
path="/route1", path="/route1",
fn=lambda r: JSONResponse("route1"), fn=lambda r: JSONResponse("route1"),
method="GET", method="GET",
authentication=True, # authentication=True,
permissions=[Permissions.administrator], # permissions=[Permissions.administrator],
) )
app.with_routes_directory("routes") app.with_routes_directory("routes")
schema = app.with_graphql()
schema.query.string_field("ping", resolver=lambda: "pong")
schema.query.with_query("hello", HelloQuery)
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
(
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
# .with_require_any_permission(PostPermissions.read)
.with_public()
)
schema.mutation.with_mutation("post", PostMutation).with_public()
schema.subscription.with_subscription(PostSubscription)
app.with_auth_root_queries(True)
app.with_auth_root_mutations(True)
app.with_playground()
app.with_graphiql()
app.with_permissions(PostPermissions)
provider = builder.service_provider provider = builder.service_provider
user_cache = provider.get_service(Cache[AuthUser]) user_cache = provider.get_service(Cache[User])
role_cache = provider.get_service(Cache[Role]) role_cache = provider.get_service(Cache[Role])
if role_cache == user_cache: if role_cache == user_cache:

View File

View File

@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Self
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
class Author(DbModelABC[Self]):
def __init__(
self,
id: int,
first_name: str,
last_name: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._first_name = first_name
self._last_name = last_name
@property
def first_name(self) -> str:
return self._first_name
@property
def last_name(self) -> str:
return self._last_name

View File

@@ -0,0 +1,11 @@
from cpl.database.abc import DbModelDaoABC
from model.author import Author
class AuthorDao(DbModelDaoABC):
def __init__(self):
DbModelDaoABC.__init__(self, Author, "authors")
self.attribute(Author.first_name, str, db_name="firstname")
self.attribute(Author.last_name, str, db_name="lastname")

View File

@@ -0,0 +1,37 @@
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from model.author import Author
class AuthorFilter(DbModelFilter[Author]):
def __init__(self):
DbModelFilter.__init__(self, public=True)
self.int_field("id")
self.string_field("firstName")
self.string_field("lastName")
class AuthorSort(Sort[Author]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("firstName", SortOrder)
self.field("lastName", SortOrder)
class AuthorGraphType(DbModelGraphType[Author]):
def __init__(self):
DbModelGraphType.__init__(self, public=True)
self.int_field(
"id",
resolver=lambda root: root.id,
).with_public(True)
self.string_field(
"firstName",
resolver=lambda root: root.first_name,
).with_public(True)
self.string_field(
"lastName",
resolver=lambda root: root.last_name,
).with_public(True)

View File

@@ -0,0 +1,44 @@
from datetime import datetime
from typing import Self
from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC
class Post(DbModelABC[Self]):
def __init__(
self,
id: int,
author_id: SerialId,
title: str,
content: str,
deleted: bool = False,
editor_id: SerialId | None = None,
created: datetime | None = None,
updated: datetime | None = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._author_id = author_id
self._title = title
self._content = content
@property
def author_id(self) -> SerialId:
return self._author_id
@property
def title(self) -> str:
return self._title
@title.setter
def title(self, value: str):
self._title = value
@property
def content(self) -> str:
return self._content
@content.setter
def content(self, value: str):
self._content = value

View File

@@ -0,0 +1,15 @@
from cpl.database.abc import DbModelDaoABC
from model.author_dao import AuthorDao
from model.post import Post
class PostDao(DbModelDaoABC[Post]):
def __init__(self, authors: AuthorDao):
DbModelDaoABC.__init__(self, Post, "posts")
self.attribute(Post.author_id, int, db_name="authorId")
self.reference("author", "id", Post.author_id, "authors", authors)
self.attribute(Post.title, str)
self.attribute(Post.content, str)

View File

@@ -0,0 +1,148 @@
from cpl.dependency.event_bus import EventBusABC
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.input import Input
from cpl.graphql.schema.mutation import Mutation
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
from cpl.graphql.schema.subscription import Subscription
from model.author_dao import AuthorDao
from model.author_query import AuthorGraphType, AuthorFilter
from model.post import Post
from model.post_dao import PostDao
class PostFilter(DbModelFilter[Post]):
def __init__(self):
DbModelFilter.__init__(self, public=True)
self.int_field("id")
self.filter_field("author", AuthorFilter)
self.string_field("title")
self.string_field("content")
class PostSort(Sort[Post]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("title", SortOrder)
self.field("content", SortOrder)
class PostGraphType(DbModelGraphType[Post]):
def __init__(self, authors: AuthorDao):
DbModelGraphType.__init__(self, public=True)
self.int_field(
"id",
resolver=lambda root: root.id,
).with_optional().with_public(True)
async def _a(root: Post):
return await authors.get_by_id(root.author_id)
def r_name(ctx: QueryContext):
return ctx.user.username == "admin"
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name]))
self.string_field(
"title",
resolver=lambda root: root.title,
).with_public(True)
self.string_field(
"content",
resolver=lambda root: root.content,
).with_public(True)
class PostCreateInput(Input[Post]):
title: str
content: str
author_id: int
def __init__(self):
Input.__init__(self)
self.string_field("title").with_required()
self.string_field("content").with_required()
self.int_field("author_id").with_required()
class PostUpdateInput(Input[Post]):
title: str
content: str
author_id: int
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.string_field("title").with_required(False)
self.string_field("content").with_required(False)
class PostSubscription(Subscription):
def __init__(self, bus: EventBusABC):
Subscription.__init__(self)
self._bus = bus
def selector(event: Post, info) -> bool:
return event.id == 101
self.subscription_field("postChange", PostGraphType, selector).with_public()
class PostMutation(Mutation):
def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC):
Mutation.__init__(self)
self._posts = posts
self._authors = authors
self._bus = bus
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
"input",
PostCreateInput,
).with_required()
self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument(
"input",
PostUpdateInput,
).with_required()
self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument(
"id",
int,
).with_required()
self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument(
"id",
int,
).with_required()
async def create_post(self, input: PostCreateInput) -> int:
return await self._posts.create(Post(0, input.author_id, input.title, input.content))
async def update_post(self, input: PostUpdateInput) -> bool:
post = await self._posts.get_by_id(input.id)
if post is None:
return False
post.title = input.title if input.title is not None else post.title
post.content = input.content if input.content is not None else post.content
await self._posts.update(post)
await self._bus.publish("postChange", post)
return True
async def delete_post(self, id: int) -> bool:
post = await self._posts.get_by_id(id)
if post is None:
return False
await self._posts.delete(post)
return True
async def restore_post(self, id: int) -> bool:
post = await self._posts.get_by_id(id)
if post is None:
return False
await self._posts.restore(post)
return True

View File

@@ -0,0 +1,8 @@
from enum import Enum
class PostPermissions(Enum):
read = "post.read"
write = "post.write"
delete = "post.delete"

View File

View File

@@ -0,0 +1,39 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
class City:
def __init__(self, id: int, name: str):
self.id = id
self.name = name
class CityFilter(Filter[City]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("name", str)
class CitySort(Sort[City]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("name", SortOrder)
class CityGraphType(GraphType[City]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"name",
resolver=lambda root: root.name,
)

View File

@@ -0,0 +1,70 @@
from queries.cities import CityFilter, CitySort, CityGraphType, City
from queries.user import User, UserFilter, UserSort, UserGraphType
from cpl.api.middleware.request import get_request
from cpl.auth.schema import UserDao, User
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.query import Query
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
users = [User(i, f"User {i}") for i in range(1, 101)]
cities = [City(i, f"City {i}") for i in range(1, 101)]
# class UserFilter(Filter[User]):
# def __init__(self):
# Filter.__init__(self)
# self.field("id", int)
# self.field("username", str)
#
#
# class UserSort(Sort[User]):
# def __init__(self):
# Sort.__init__(self)
# self.field("id", SortOrder)
# self.field("username", SortOrder)
#
# class UserGraphType(GraphType[User]):
#
# def __init__(self):
# GraphType.__init__(self)
#
# self.int_field(
# "id",
# resolver=lambda root: root.id,
# )
# self.string_field(
# "username",
# resolver=lambda root: root.username,
# )
class HelloQuery(Query):
def __init__(self):
Query.__init__(self)
self.string_field(
"message",
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
).with_argument("name", str, "Name to greet", "world")
self.collection_field(
UserGraphType,
"users",
UserFilter,
UserSort,
resolver=lambda: users,
)
self.collection_field(
CityGraphType,
"cities",
CityFilter,
CitySort,
resolver=lambda: cities,
)
# self.dao_collection_field(
# UserGraphType,
# UserDao,
# "Users",
# UserFilter,
# UserSort,
# )

View File

@@ -0,0 +1,39 @@
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.graph_type import GraphType
from cpl.graphql.schema.sort.sort import Sort
from cpl.graphql.schema.sort.sort_order import SortOrder
class User:
def __init__(self, id: int, name: str):
self.id = id
self.name = name
class UserFilter(Filter[User]):
def __init__(self):
Filter.__init__(self)
self.field("id", int)
self.field("name", str)
class UserSort(Sort[User]):
def __init__(self):
Sort.__init__(self)
self.field("id", SortOrder)
self.field("name", SortOrder)
class UserGraphType(GraphType[User]):
def __init__(self):
GraphType.__init__(self)
self.int_field(
"id",
resolver=lambda root: root.id,
)
self.string_field(
"name",
resolver=lambda root: root.name,
)

View File

@@ -0,0 +1,22 @@
CREATE TABLE IF NOT EXISTS `authors` (
`id` INT(30) NOT NULL AUTO_INCREMENT,
`firstname` VARCHAR(64) NOT NULL,
`lastname` VARCHAR(64) NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY(`id`)
);
CREATE TABLE IF NOT EXISTS `posts` (
`id` INT(30) NOT NULL AUTO_INCREMENT,
`authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE,
`title` TEXT NOT NULL,
`content` TEXT NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY(`id`)
);

View File

@@ -0,0 +1,48 @@
from faker import Faker
from cpl.database.abc import DataSeederABC
from cpl.query import Enumerable
from model.author import Author
from model.author_dao import AuthorDao
from model.post import Post
from model.post_dao import PostDao
fake = Faker()
class TestDataSeeder(DataSeederABC):
def __init__(self, authors: AuthorDao, posts: PostDao):
DataSeederABC.__init__(self)
self._authors = authors
self._posts = posts
async def seed(self):
if await self._authors.count() == 0:
await self._seed_authors()
if await self._posts.count() == 0:
await self._seed_posts()
async def _seed_authors(self):
authors = Enumerable.range(0, 35).select(
lambda x: Author(
0,
fake.first_name(),
fake.last_name(),
)
).to_list()
await self._authors.create_many(authors, skip_editor=True)
async def _seed_posts(self):
posts = Enumerable.range(0, 100).select(
lambda x: Post(
id=0,
author_id=fake.random_int(min=1, max=35),
title=fake.sentence(nb_words=6),
content=fake.paragraph(nb_sentences=6),
)
).to_list()
await self._posts.create_many(posts, skip_editor=True)

View File

@@ -5,16 +5,16 @@ from cpl.core.typing import SerialId
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
class City(DbModelABC): class City(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: int, id: int,
name: str, name: str,
zip: str, zip: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -5,7 +5,7 @@ from cpl.core.typing import SerialId
from cpl.database.abc.db_model_abc import DbModelABC from cpl.database.abc.db_model_abc import DbModelABC
class User(DbModelABC): class User(DbModelABC[Self]):
def __init__( def __init__(
self, self,
@@ -13,9 +13,9 @@ class User(DbModelABC):
name: str, name: str,
city_id: int = 0, city_id: int = 0,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None= None,
updated: Optional[datetime] = None, updated: datetime | None= None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -0,0 +1,45 @@
from abc import ABC
from enum import Enum
from typing import Self
from starlette.applications import Starlette
from cpl.api.model.api_route import ApiRoute
from cpl.api.model.validation_match import ValidationMatch
from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput
from cpl.application.abc.application_abc import ApplicationABC
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules
class WebAppABC(ApplicationABC, ABC):
def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
ApplicationABC.__init__(self, services, modules, required_modules)
def with_routes_directory(self, directory: str) -> Self: ...
def with_app(self, app: Starlette) -> Self: ...
def with_routes(
self,
routes: list[ApiRoute],
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self: ...
def with_route(
self,
path: str,
fn: TEndpoint,
method: HTTPMethods,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self: ...
def with_middleware(self, middleware: PartialMiddleware) -> Self: ...
def with_authentication(self) -> Self: ...
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: ...

View File

@@ -1,6 +1,6 @@
import os import os
from enum import Enum from enum import Enum
from typing import Mapping, Any, Callable, Self, Union from typing import Mapping, Any, Self
import uvicorn import uvicorn
from starlette.applications import Starlette from starlette.applications import Starlette
@@ -10,6 +10,7 @@ from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.types import ExceptionHandler from starlette.types import ExceptionHandler
from cpl.api.abc.web_app_abc import WebAppABC
from cpl.api.api_module import ApiModule from cpl.api.api_module import ApiModule
from cpl.api.error import APIError from cpl.api.error import APIError
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
@@ -24,8 +25,7 @@ from cpl.api.registry.policy import PolicyRegistry
from cpl.api.registry.route import RouteRegistry from cpl.api.registry.route import RouteRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.api.settings import ApiSettings from cpl.api.settings import ApiSettings
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput
from cpl.application.abc.application_abc import ApplicationABC
from cpl.auth.auth_module import AuthModule from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_module import PermissionsModule from cpl.auth.permission.permission_module import PermissionsModule
from cpl.core.configuration.configuration import Configuration from cpl.core.configuration.configuration import Configuration
@@ -33,12 +33,12 @@ from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules from cpl.dependency.typing import Modules
PolicyInput = Union[dict[str, PolicyResolver], Policy]
class WebApp(WebAppABC):
class WebApp(ApplicationABC): def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
def __init__(self, services: ServiceProvider, modules: Modules): WebAppABC.__init__(
super().__init__(services, modules, [AuthModule, PermissionsModule, ApiModule]) self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])
)
self._app: Starlette | None = None self._app: Starlette | None = None
self._logger = services.get_service(APILogger) self._logger = services.get_service(APILogger)
@@ -78,16 +78,17 @@ class WebApp(ApplicationABC):
self._logger.debug(f"Allowed origins: {origins}") self._logger.debug(f"Allowed origins: {origins}")
return origins.split(",") return origins.split(",")
def with_app(self, app: Starlette) -> Self:
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
self._app = app
return self
def _check_for_app(self): def _check_for_app(self):
if self._app is not None: if self._app is not None:
raise ValueError("App is already set, cannot add routes or middleware") raise ValueError("App is already set, cannot add routes or middleware")
def _validate_policies(self):
for rule in Router.get_authorization_rules():
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
def with_routes_directory(self, directory: str) -> Self: def with_routes_directory(self, directory: str) -> Self:
self._check_for_app() self._check_for_app()
assert directory is not None, "directory must not be None" assert directory is not None, "directory must not be None"
@@ -102,6 +103,12 @@ class WebApp(ApplicationABC):
return self return self
def with_app(self, app: Starlette) -> Self:
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
self._app = app
return self
def with_routes( def with_routes(
self, self,
routes: list[ApiRoute], routes: list[ApiRoute],
@@ -131,7 +138,7 @@ class WebApp(ApplicationABC):
def with_route( def with_route(
self, self,
path: str, path: str,
fn: Callable[[Request], Any], fn: TEndpoint,
method: HTTPMethods, method: HTTPMethods,
authentication: bool = False, authentication: bool = False,
roles: list[str | Enum] = None, roles: list[str | Enum] = None,
@@ -162,6 +169,30 @@ class WebApp(ApplicationABC):
return self return self
def with_websocket(
self,
path: str,
fn: TEndpoint,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self._check_for_app()
assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None"
Router.websocket(path, registry=self._routes)(fn)
if authentication:
Router.authenticate()(fn)
if roles or permissions or policies:
Router.authorize(roles, permissions, policies, match)(fn)
return self
def with_middleware(self, middleware: PartialMiddleware) -> Self: def with_middleware(self, middleware: PartialMiddleware) -> Self:
self._check_for_app() self._check_for_app()
@@ -179,6 +210,7 @@ class WebApp(ApplicationABC):
return self return self
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
self._check_for_app()
if policies: if policies:
_policies = [] _policies = []
@@ -206,12 +238,8 @@ class WebApp(ApplicationABC):
self.with_middleware(AuthorizationMiddleware) self.with_middleware(AuthorizationMiddleware)
return self return self
def _validate_policies(self): async def _log_before_startup(self):
for rule in Router.get_authorization_rules(): self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
for policy_name in rule["policies"]:
policy = self._policies.get(policy_name)
if not policy:
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
async def main(self): async def main(self):
self._logger.debug(f"Preparing API") self._logger.debug(f"Preparing API")
@@ -236,7 +264,7 @@ class WebApp(ApplicationABC):
else: else:
app = self._app app = self._app
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}") await self._log_before_startup()
config = uvicorn.Config( config = uvicorn.Config(
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio" app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"

View File

@@ -8,7 +8,7 @@ class APIError(HTTPException):
status_code = 500 status_code = 500
def __init__(self, message: str = ""): def __init__(self, message: str = ""):
super().__init__(self.status_code, message) HTTPException.__init__(self, self.status_code, message)
self._message = message self._message = message
@property @property

View File

@@ -7,13 +7,13 @@ from cpl.api.logger import APILogger
from cpl.api.middleware.request import get_request from cpl.api.middleware.request import get_request
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser from cpl.auth.schema import UserDao, User
from cpl.core.ctx import set_user from cpl.core.ctx import set_user
class AuthenticationMiddleware(ASGIMiddleware): class AuthenticationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: AuthUserDao): def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger self._logger = logger
@@ -25,6 +25,21 @@ class AuthenticationMiddleware(ASGIMiddleware):
request = get_request() request = get_request()
url = request.url.path url = request.url.path
if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send)
user = getattr(request.state, "user", None)
if not user or user.deleted:
self._logger.debug(f"Unauthorized access to {url}, user missing or deleted")
return await Unauthorized("Unauthorized").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
async def _old_call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes(): if url not in Router.get_auth_required_routes():
self._logger.trace(f"No authentication required for {url}") self._logger.trace(f"No authentication required for {url}")
return await self._app(scope, receive, send) return await self._app(scope, receive, send)
@@ -57,12 +72,12 @@ class AuthenticationMiddleware(ASGIMiddleware):
return await self._call_next(scope, receive, send) return await self._call_next(scope, receive, send)
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser: async def _get_or_crate_user(self, keycloak_id: str) -> User:
existing = await self._user_dao.find_by_keycloak_id(keycloak_id) existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
if existing is not None: if existing is not None:
return existing return existing
user = AuthUser(0, keycloak_id) user = User(0, keycloak_id)
uid = await self._user_dao.create(user) uid = await self._user_dao.create(user)
return await self._user_dao.get_by_id(uid) return await self._user_dao.get_by_id(uid)

View File

@@ -7,13 +7,13 @@ from cpl.api.middleware.request import get_request
from cpl.api.model.validation_match import ValidationMatch from cpl.api.model.validation_match import ValidationMatch
from cpl.api.registry.policy import PolicyRegistry from cpl.api.registry.policy import PolicyRegistry
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.user_dao import UserDao
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
class AuthorizationMiddleware(ASGIMiddleware): class AuthorizationMiddleware(ASGIMiddleware):
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: AuthUserDao): def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: UserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._logger = logger self._logger = logger

View File

@@ -5,10 +5,15 @@ from uuid import uuid4
from starlette.requests import Request from starlette.requests import Request
from starlette.types import Scope, Receive, Send from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.logger import APILogger from cpl.api.logger import APILogger
from cpl.api.typing import TRequest from cpl.api.typing import TRequest
from cpl.auth.keycloak.keycloak_client import KeycloakClient
from cpl.auth.schema import User
from cpl.auth.schema._administration.user_dao import UserDao
from cpl.core.ctx import set_user
from cpl.dependency.inject import inject from cpl.dependency.inject import inject
from cpl.dependency.service_provider import ServiceProvider from cpl.dependency.service_provider import ServiceProvider
@@ -17,19 +22,23 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
class RequestMiddleware(ASGIMiddleware): class RequestMiddleware(ASGIMiddleware):
def __init__(self, app, provider: ServiceProvider, logger: APILogger): def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
ASGIMiddleware.__init__(self, app) ASGIMiddleware.__init__(self, app)
self._provider = provider self._provider = provider
self._logger = logger self._logger = logger
self._keycloak = keycloak
self._user_dao = user_dao
self._ctx_token = None self._ctx_token = None
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = Request(scope, receive, send) request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send)
await self.set_request_data(request) await self.set_request_data(request)
try: try:
await self._try_set_user(request)
with self._provider.create_scope(): with self._provider.create_scope():
inject(await self._app(scope, receive, send)) inject(await self._app(scope, receive, send))
finally: finally:
@@ -53,6 +62,37 @@ class RequestMiddleware(ASGIMiddleware):
self._logger.trace(f"Clearing current request: {request.state.request_id}") self._logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token) _request_context.reset(self._ctx_token)
async def _try_set_user(self, request: Request):
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return
token = auth_header.split("Bearer ")[1]
try:
token_info = self._keycloak.introspect(token)
if not token_info.get("active", False):
return
keycloak_id = self._keycloak.get_user_id(token)
if not keycloak_id:
return
user = await self._user_dao.find_by_keycloak_id(keycloak_id)
if not user:
user = User(0, keycloak_id)
uid = await self._user_dao.create(user)
user = await self._user_dao.get_by_id(uid)
if user.deleted:
return
request.state.user = user
set_user(user)
self._logger.trace(f"User {user.id} bound to request {request.state.request_id}")
except Exception as e:
self._logger.debug(f"Silent user binding failed: {e}")
def get_request() -> Optional[TRequest]: def get_request() -> Optional[TRequest]:
return _request_context.get() return _request_context.get()

View File

@@ -0,0 +1,31 @@
from typing import Callable
import starlette.routing
class WebSocketRoute:
def __init__(self, path: str, fn: Callable, **kwargs):
self._path = path
self._fn = fn
self._kwargs = kwargs
@property
def name(self) -> str:
return self._fn.__name__
@property
def fn(self) -> Callable:
return self._fn
@property
def path(self) -> str:
return self._path
@property
def kwargs(self) -> dict:
return self._kwargs
def to_starlette(self, *args) -> starlette.routing.WebSocketRoute:
return starlette.routing.WebSocketRoute(self._path, self._fn)

View File

@@ -1,32 +1,35 @@
from typing import Optional from typing import Optional, Union
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute
from cpl.api.model.websocket_route import WebSocketRoute
from cpl.core.abc.registry_abc import RegistryABC from cpl.core.abc.registry_abc import RegistryABC
TRoute = Union[ApiRoute, WebSocketRoute]
class RouteRegistry(RegistryABC): class RouteRegistry(RegistryABC):
def __init__(self): def __init__(self):
RegistryABC.__init__(self) RegistryABC.__init__(self)
def extend(self, items: list[ApiRoute]): def extend(self, items: list[TRoute]):
for policy in items: for policy in items:
self.add(policy) self.add(policy)
def add(self, item: ApiRoute): def add(self, item: TRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute"
if item.path in self._items: if item.path in self._items:
raise ValueError(f"ApiRoute {item.path} is already registered") raise ValueError(f"ApiRoute {item.path} is already registered")
self._items[item.path] = item self._items[item.path] = item
def set(self, item: ApiRoute): def set(self, item: TRoute):
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute" assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
self._items[item.path] = item self._items[item.path] = item
def get(self, key: str) -> Optional[ApiRoute]: def get(self, key: str) -> Optional[TRoute]:
return self._items.get(key) return self._items.get(key)
def all(self) -> list[ApiRoute]: def all(self) -> list[TRoute]:
return list(self._items.values()) return list(self._items.values())

View File

@@ -91,6 +91,22 @@ class Router:
return inner return inner
@classmethod
def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.websocket_route import WebSocketRoute
if not registry:
routes = get_provider().get_service(RouteRegistry)
else:
routes = registry
def inner(fn):
routes.add(WebSocketRoute(path, fn, **kwargs))
setattr(fn, "_route_path", path)
return fn
return inner
@classmethod @classmethod
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs): def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
from cpl.api.model.api_route import ApiRoute from cpl.api.model.api_route import ApiRoute

View File

@@ -6,7 +6,7 @@ from cpl.core.configuration import ConfigurationModelABC
class ApiSettings(ConfigurationModelABC): class ApiSettings(ConfigurationModelABC):
def __init__(self, src: Optional[dict] = None): def __init__(self, src: Optional[dict] = None):
super().__init__(src) ConfigurationModelABC.__init__(self, src)
self.option("host", str, "0.0.0.0") self.option("host", str, "0.0.0.0")
self.option("port", int, 5000) self.option("port", int, 5000)

View File

@@ -2,13 +2,15 @@ from typing import Union, Literal, Callable, Type, Awaitable
from urllib.request import Request from urllib.request import Request
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.types import ASGIApp from starlette.types import ASGIApp
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.auth.schema import AuthUser from cpl.auth.schema import User
TRequest = Union[Request, WebSocket] TRequest = Union[Request, WebSocket]
TEndpoint = Callable[[TRequest, ...], Awaitable[Response]] | Callable[[TRequest, ...], Response]
HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
PartialMiddleware = Union[ PartialMiddleware = Union[
ASGIMiddleware, ASGIMiddleware,
@@ -16,4 +18,5 @@ PartialMiddleware = Union[
Middleware, Middleware,
Callable[[ASGIApp], ASGIApp], Callable[[ASGIApp], ASGIApp],
] ]
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]] PolicyResolver = Callable[[User], bool | Awaitable[bool]]
PolicyInput = Union[dict[str, PolicyResolver], "Policy"]

View File

@@ -56,7 +56,7 @@ class ApplicationABC(ABC):
module_dependency_error( module_dependency_error(
type(self).__name__, type(self).__name__,
module.__name__, module.__name__ if not isinstance(module, str) else module,
ImportError( ImportError(
f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method." f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
), ),

View File

@@ -12,7 +12,7 @@ from cpl.dependency.service_provider import ServiceProvider
from .keycloak.keycloak_admin import KeycloakAdmin from .keycloak.keycloak_admin import KeycloakAdmin
from .keycloak.keycloak_client import KeycloakClient from .keycloak.keycloak_client import KeycloakClient
from .schema._administration.api_key_dao import ApiKeyDao from .schema._administration.api_key_dao import ApiKeyDao
from .schema._administration.auth_user_dao import AuthUserDao from .schema._administration.user_dao import UserDao
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
from .schema._permission.permission_dao import PermissionDao from .schema._permission.permission_dao import PermissionDao
from .schema._permission.role_dao import RoleDao from .schema._permission.role_dao import RoleDao
@@ -26,7 +26,7 @@ class AuthModule(Module):
singleton = [ singleton = [
KeycloakClient, KeycloakClient,
KeycloakAdmin, KeycloakAdmin,
AuthUserDao, UserDao,
ApiKeyDao, ApiKeyDao,
ApiKeyPermissionDao, ApiKeyPermissionDao,
PermissionDao, PermissionDao,

View File

@@ -2,6 +2,7 @@ from cpl.auth.auth_module import AuthModule
from cpl.auth.permission.permission_seeder import PermissionSeeder from cpl.auth.permission.permission_seeder import PermissionSeeder
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.auth.permission.role_seeder import RoleSeeder
from cpl.database.abc.data_seeder_abc import DataSeederABC from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.database_module import DatabaseModule from cpl.database.database_module import DatabaseModule
from cpl.dependency.module.module import Module from cpl.dependency.module.module import Module
@@ -10,7 +11,7 @@ from cpl.dependency.service_collection import ServiceCollection
class PermissionsModule(Module): class PermissionsModule(Module):
dependencies = [DatabaseModule, AuthModule] dependencies = [DatabaseModule, AuthModule]
singleton = [(DataSeederABC, PermissionSeeder)] transient = [(DataSeederABC, PermissionSeeder), (DataSeederABC, RoleSeeder)]
@staticmethod @staticmethod
def register(collection: ServiceCollection): def register(collection: ServiceCollection):

View File

@@ -1,4 +1,3 @@
from cpl.auth.permission.permissions import Permissions
from cpl.auth.permission.permissions_registry import PermissionsRegistry from cpl.auth.permission.permissions_registry import PermissionsRegistry
from cpl.auth.schema import ( from cpl.auth.schema import (
Permission, Permission,

View File

@@ -0,0 +1,60 @@
from cpl.auth.schema import (
Role,
RolePermission,
PermissionDao,
RoleDao,
RolePermissionDao,
ApiKeyDao,
ApiKeyPermissionDao,
UserDao,
RoleUserDao,
RoleUser,
)
from cpl.database.abc.data_seeder_abc import DataSeederABC
from cpl.database.logger import DBLogger
class RoleSeeder(DataSeederABC):
def __init__(
self,
logger: DBLogger,
permission_dao: PermissionDao,
role_dao: RoleDao,
role_permission_dao: RolePermissionDao,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
user_dao: UserDao,
role_user_dao: RoleUserDao,
):
DataSeederABC.__init__(self)
self._logger = logger
self._permission_dao = permission_dao
self._role_dao = role_dao
self._role_permission_dao = role_permission_dao
self._api_key_dao = api_key_dao
self._api_key_permission_dao = api_key_permission_dao
self._user_dao = user_dao
self._role_user_dao = role_user_dao
async def seed(self):
self._logger.info("Creating admin role")
roles = await self._role_dao.get_all()
if len(roles) == 0:
rid = await self._role_dao.create(Role(0, "admin", "Default admin role"))
permissions = await self._permission_dao.get_all()
await self._role_permission_dao.create_many(
[RolePermission(0, rid, permission.id) for permission in permissions]
)
role = await self._role_dao.get_by_name("admin")
if len(await role.users) > 0:
return
users = await self._user_dao.get_all()
if len(users) == 0:
return
user = users[0]
self._logger.warning(f"Assigning admin role to first user {user.id}")
await self._role_user_dao.create(RoleUser(0, role.id, user.id))

View File

@@ -1,7 +1,7 @@
from ._administration.api_key import ApiKey from ._administration.api_key import ApiKey
from ._administration.api_key_dao import ApiKeyDao from ._administration.api_key_dao import ApiKeyDao
from ._administration.auth_user import AuthUser from ._administration.user import User
from ._administration.auth_user_dao import AuthUserDao from ._administration.user_dao import UserDao
from ._permission.api_key_permission import ApiKeyPermission from ._permission.api_key_permission import ApiKeyPermission
from ._permission.api_key_permission_dao import ApiKeyPermissionDao from ._permission.api_key_permission_dao import ApiKeyPermissionDao

View File

@@ -1,6 +1,6 @@
import secrets import secrets
from datetime import datetime from datetime import datetime
from typing import Optional, Union from typing import Optional, Union, Self
from async_property import async_property from async_property import async_property
@@ -16,7 +16,7 @@ from cpl.dependency.service_provider import ServiceProvider
_logger = Logger(__name__) _logger = Logger(__name__)
class ApiKey(DbModelABC): class ApiKey(DbModelABC[Self]):
def __init__( def __init__(
self, self,
@@ -25,8 +25,8 @@ class ApiKey(DbModelABC):
key: Union[str, bytes], key: Union[str, bytes],
deleted: bool = False, deleted: bool = False,
editor_id: Optional[Id] = None, editor_id: Optional[Id] = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._identifier = identifier self._identifier = identifier

View File

@@ -1,6 +1,6 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from async_property import async_property from async_property import async_property
from keycloak import KeycloakGetError from keycloak import KeycloakGetError
@@ -10,18 +10,18 @@ from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.database.logger import DBLogger from cpl.database.logger import DBLogger
from cpl.dependency import ServiceProvider from cpl.dependency import get_provider
class AuthUser(DbModelABC): class User(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
keycloak_id: str, keycloak_id: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._keycloak_id = keycloak_id self._keycloak_id = keycloak_id
@@ -69,21 +69,21 @@ class AuthUser(DbModelABC):
@async_property @async_property
async def permissions(self): async def permissions(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.user_dao import UserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) user_dao: UserDao = get_provider().get_service(UserDao)
return await auth_user_dao.get_permissions(self.id) return await user_dao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.user_dao import UserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) user_dao: UserDao = get_provider().get_service(UserDao)
return await auth_user_dao.has_permission(self.id, permission) return await user_dao.has_permission(self.id, permission)
async def anonymize(self): async def anonymize(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.user_dao import UserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) user_dao: UserDao = get_provider().get_service(UserDao)
self._keycloak_id = str(uuid.UUID(int=0)) self._keycloak_id = str(uuid.UUID(int=0))
await auth_user_dao.update(self) await user_dao.update(self)

View File

@@ -1,19 +1,23 @@
from typing import Optional, Union from typing import Optional, Union
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._permission.permission_dao import PermissionDao
from cpl.auth.schema._permission.permission import Permission
from cpl.auth.schema._administration.user import User
from cpl.database import TableManager from cpl.database import TableManager
from cpl.database.abc import DbModelDaoABC from cpl.database.abc import DbModelDaoABC
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
from cpl.dependency import ServiceProvider from cpl.dependency.context import get_provider
class AuthUserDao(DbModelDaoABC[AuthUser]): class UserDao(DbModelDaoABC[User]):
def __init__(self): def __init__(self, permission_dao: PermissionDao):
DbModelDaoABC.__init__(self, AuthUser, TableManager.get("auth_users")) DbModelDaoABC.__init__(self, User, TableManager.get("users"))
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId") self._permissions = permission_dao
self.attribute(User.keycloak_id, str)
async def get_users(): async def get_users():
return [(x.id, x.username, x.email) for x in await self.get_all()] return [(x.id, x.username, x.email) for x in await self.get_all()]
@@ -27,11 +31,11 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
.with_value_getter(get_users) .with_value_getter(get_users)
) )
async def get_by_keycloak_id(self, keycloak_id: str) -> AuthUser: async def get_by_keycloak_id(self, keycloak_id: str) -> User:
return await self.get_single_by({AuthUser.keycloak_id: keycloak_id}) return await self.get_single_by({User.keycloak_id: keycloak_id})
async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[AuthUser]: async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[User]:
return await self.find_single_by({AuthUser.keycloak_id: keycloak_id}) return await self.find_single_by({User.keycloak_id: keycloak_id})
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool: async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
@@ -54,7 +58,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
return result[0]["count"] > 0 return result[0]["count"] > 0
async def get_permissions(self, user_id: int) -> list[Permissions]: async def get_permissions(self, user_id: int) -> list[Permission]:
result = await self._db.select_map( result = await self._db.select_map(
f""" f"""
SELECT p.* SELECT p.*
@@ -66,4 +70,4 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
AND ru.deleted = FALSE; AND ru.deleted = FALSE;
""" """
) )
return [Permissions(p["name"]) for p in result] return [self._permissions.to_object(x) for x in result]

View File

@@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC):
api_key_id: SerialId, api_key_id: SerialId,
permission_id: SerialId, permission_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
self._api_key_id = api_key_id self._api_key_id = api_key_id

View File

@@ -1,20 +1,20 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
class Permission(DbModelABC): class Permission(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
name: str, name: str,
description: str, description: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -1,24 +1,24 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from async_property import async_property from async_property import async_property
from cpl.auth.permission.permissions import Permissions from cpl.auth.permission.permissions import Permissions
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider, get_provider
class Role(DbModelABC): class Role(DbModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
name: str, name: str,
description: str, description: str,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name self._name = name

View File

@@ -1,46 +1,44 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Self
from async_property import async_property from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import get_provider
class RolePermission(DbModelABC): class RolePermission(DbJoinModelABC[Self]):
def __init__( def __init__(
self, self,
id: SerialId, id: SerialId,
role_id: SerialId, role_id: SerialId,
permission_id: SerialId, permission_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, id, role_id, permission_id, deleted, editor_id, created, updated)
self._role_id = role_id
self._permission_id = permission_id
@property @property
def role_id(self) -> int: def role_id(self) -> int:
return self._role_id return self._source_id
@async_property @async_property
async def role(self): async def role(self):
from cpl.auth.schema._permission.role_dao import RoleDao from cpl.auth.schema._permission.role_dao import RoleDao
role_dao: RoleDao = get_provider().get_service(RoleDao) role_dao: RoleDao = get_provider().get_service(RoleDao)
return await role_dao.get_by_id(self._role_id) return await role_dao.get_by_id(self._source_id)
@property @property
def permission_id(self) -> int: def permission_id(self) -> int:
return self._permission_id return self._foreign_id
@async_property @async_property
async def permission(self): async def permission(self):
from cpl.auth.schema._permission.permission_dao import PermissionDao from cpl.auth.schema._permission.permission_dao import PermissionDao
permission_dao: PermissionDao = get_provider().get_service(PermissionDao) permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
return await permission_dao.get_by_id(self._permission_id) return await permission_dao.get_by_id(self._foreign_id)

View File

@@ -5,7 +5,7 @@ from async_property import async_property
from cpl.core.typing import SerialId from cpl.core.typing import SerialId
from cpl.database.abc import DbJoinModelABC from cpl.database.abc import DbJoinModelABC
from cpl.dependency import ServiceProvider from cpl.dependency import ServiceProvider, get_provider
class RoleUser(DbJoinModelABC): class RoleUser(DbJoinModelABC):
@@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC):
user_id: SerialId, user_id: SerialId,
role_id: SerialId, role_id: SerialId,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated) DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
self._user_id = user_id self._user_id = user_id
@@ -29,10 +29,10 @@ class RoleUser(DbJoinModelABC):
@async_property @async_property
async def user(self): async def user(self):
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao from cpl.auth.schema._administration.user_dao import UserDao
auth_user_dao: AuthUserDao = get_provider().get_service(AuthUserDao) user_dao: UserDao = get_provider().get_service(UserDao)
return await auth_user_dao.get_by_id(self._user_id) return await user_dao.get_by_id(self._user_id)
@property @property
def role_id(self) -> int: def role_id(self) -> int:

View File

@@ -1,4 +1,4 @@
CREATE TABLE IF NOT EXISTS administration_auth_users CREATE TABLE IF NOT EXISTS administration_users
( (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
keycloakId CHAR(36) NOT NULL, keycloakId CHAR(36) NOT NULL,
@@ -9,10 +9,10 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UC_KeycloakId UNIQUE (keycloakId), CONSTRAINT UC_KeycloakId UNIQUE (keycloakId),
CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS administration_auth_users_history CREATE TABLE IF NOT EXISTS administration_users_history
( (
id INT NOT NULL, id INT NOT NULL,
keycloakId CHAR(36) NOT NULL, keycloakId CHAR(36) NOT NULL,
@@ -23,22 +23,22 @@ CREATE TABLE IF NOT EXISTS administration_auth_users_history
updated TIMESTAMP NOT NULL updated TIMESTAMP NOT NULL
); );
CREATE TRIGGER TR_administration_auth_usersUpdate CREATE TRIGGER TR_administration_usersUpdate
AFTER UPDATE AFTER UPDATE
ON administration_auth_users ON administration_users
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO administration_auth_users_history INSERT INTO administration_users_history
(id, keycloakId, deleted, editorId, created, updated) (id, keycloakId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.keycloakId, OLD.deleted, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.keycloakId, OLD.deleted, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TRIGGER TR_administration_auth_usersDelete CREATE TRIGGER TR_administration_usersDelete
AFTER DELETE AFTER DELETE
ON administration_auth_users ON administration_users
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO administration_auth_users_history INSERT INTO administration_users_history
(id, keycloakId, deleted, editorId, created, updated) (id, keycloakId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.keycloakId, 1, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.keycloakId, 1, OLD.editorId, OLD.created, NOW());
END; END;

View File

@@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
CONSTRAINT UC_Identifier_Key UNIQUE (identifier, keyString), CONSTRAINT UC_Identifier_Key UNIQUE (identifier, keyString),
CONSTRAINT UC_Key UNIQUE (keyString), CONSTRAINT UC_Key UNIQUE (keyString),
CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS administration_api_keys_history CREATE TABLE IF NOT EXISTS administration_api_keys_history

View File

@@ -8,7 +8,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UQ_PermissionName UNIQUE (name), CONSTRAINT UQ_PermissionName UNIQUE (name),
CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS permission_permissions_history CREATE TABLE IF NOT EXISTS permission_permissions_history
@@ -52,7 +52,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UQ_RoleName UNIQUE (name), CONSTRAINT UQ_RoleName UNIQUE (name),
CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS permission_roles_history CREATE TABLE IF NOT EXISTS permission_roles_history
@@ -89,22 +89,22 @@ END;
CREATE TABLE IF NOT EXISTS permission_role_permissions CREATE TABLE IF NOT EXISTS permission_role_permissions
( (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
RoleId INT NOT NULL, roleId INT NOT NULL,
permissionId INT NOT NULL, permissionId INT NOT NULL,
deleted BOOL NOT NULL DEFAULT FALSE, deleted BOOL NOT NULL DEFAULT FALSE,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId), CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId),
CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE,
CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS permission_role_permissions_history CREATE TABLE IF NOT EXISTS permission_role_permissions_history
( (
id INT NOT NULL, id INT NOT NULL,
RoleId INT NOT NULL, roleId INT NOT NULL,
permissionId INT NOT NULL, permissionId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
editorId INT NULL, editorId INT NULL,
@@ -118,8 +118,8 @@ CREATE TRIGGER TR_RolePermissionsUpdate
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_permissions_history INSERT INTO permission_role_permissions_history
(id, RoleId, permissionId, deleted, editorId, created, updated) (id, roleId, permissionId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TRIGGER TR_RolePermissionsDelete CREATE TRIGGER TR_RolePermissionsDelete
@@ -128,52 +128,52 @@ CREATE TRIGGER TR_RolePermissionsDelete
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_permissions_history INSERT INTO permission_role_permissions_history
(id, RoleId, permissionId, deleted, editorId, created, updated) (id, roleId, permissionId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TABLE IF NOT EXISTS permission_role_auth_users CREATE TABLE IF NOT EXISTS permission_role_users
( (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
RoleId INT NOT NULL, roleId INT NOT NULL,
UserId INT NOT NULL, userId INT NOT NULL,
deleted BOOL NOT NULL DEFAULT FALSE, deleted BOOL NOT NULL DEFAULT FALSE,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId), CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId),
CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE, CONSTRAINT FK_Roleusers_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (UserId) REFERENCES administration_auth_users (id) ON DELETE CASCADE, CONSTRAINT FK_Roleusers_User FOREIGN KEY (userId) REFERENCES administration_users (id) ON DELETE CASCADE,
CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_Roleusers_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history CREATE TABLE IF NOT EXISTS permission_role_users_history
( (
id INT NOT NULL, id INT NOT NULL,
RoleId INT NOT NULL, roleId INT NOT NULL,
UserId INT NOT NULL, userId INT NOT NULL,
deleted BOOL NOT NULL, deleted BOOL NOT NULL,
editorId INT NULL, editorId INT NULL,
created TIMESTAMP NOT NULL, created TIMESTAMP NOT NULL,
updated TIMESTAMP NOT NULL updated TIMESTAMP NOT NULL
); );
CREATE TRIGGER TR_Roleauth_usersUpdate CREATE TRIGGER TR_RoleusersUpdate
AFTER UPDATE AFTER UPDATE
ON permission_role_auth_users ON permission_role_users
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_auth_users_history INSERT INTO permission_role_users_history
(id, RoleId, UserId, deleted, editorId, created, updated) (id, roleId, userId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.UserId, OLD.deleted, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW());
END; END;
CREATE TRIGGER TR_Roleauth_usersDelete CREATE TRIGGER TR_RoleusersDelete
AFTER DELETE AFTER DELETE
ON permission_role_auth_users ON permission_role_users
FOR EACH ROW FOR EACH ROW
BEGIN BEGIN
INSERT INTO permission_role_auth_users_history INSERT INTO permission_role_users_history
(id, RoleId, UserId, deleted, editorId, created, updated) (id, roleId, userId, deleted, editorId, created, updated)
VALUES (OLD.id, OLD.RoleId, OLD.UserId, 1, OLD.editorId, OLD.created, NOW()); VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW());
END; END;

View File

@@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId), CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId),
CONSTRAINT FK_ApiKeyPermissions_ApiKey FOREIGN KEY (apiKeyId) REFERENCES administration_api_keys (id) ON DELETE CASCADE, CONSTRAINT FK_ApiKeyPermissions_ApiKey FOREIGN KEY (apiKeyId) REFERENCES administration_api_keys (id) ON DELETE CASCADE,
CONSTRAINT FK_ApiKeyPermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE, CONSTRAINT FK_ApiKeyPermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE,
CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id) CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
); );
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history

View File

@@ -1,26 +1,26 @@
CREATE SCHEMA IF NOT EXISTS administration; CREATE SCHEMA IF NOT EXISTS administration;
CREATE TABLE IF NOT EXISTS administration.auth_users CREATE TABLE IF NOT EXISTS administration.users
( (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
keycloakId UUID NOT NULL, keycloakId UUID NOT NULL,
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),
CONSTRAINT UC_KeycloakId UNIQUE (keycloakId) CONSTRAINT UC_KeycloakId UNIQUE (keycloakId)
); );
CREATE TABLE IF NOT EXISTS administration.auth_users_history CREATE TABLE IF NOT EXISTS administration.users_history
( (
LIKE administration.auth_users LIKE administration.users
); );
CREATE TRIGGER users_history_trigger CREATE TRIGGER users_history_trigger
BEFORE INSERT OR UPDATE OR DELETE BEFORE INSERT OR UPDATE OR DELETE
ON administration.auth_users ON administration.users
FOR EACH ROW FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function(); EXECUTE FUNCTION public.history_trigger_function();

View File

@@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS administration.api_keys
keyString VARCHAR(255) NOT NULL, keyString VARCHAR(255) NOT NULL,
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),

View File

@@ -9,7 +9,7 @@ CREATE TABLE permission.permissions
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),
CONSTRAINT UQ_PermissionName UNIQUE (name) CONSTRAINT UQ_PermissionName UNIQUE (name)
@@ -35,7 +35,7 @@ CREATE TABLE permission.roles
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),
CONSTRAINT UQ_RoleName UNIQUE (name) CONSTRAINT UQ_RoleName UNIQUE (name)
@@ -61,7 +61,7 @@ CREATE TABLE permission.role_permissions
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),
CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId) CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId)
@@ -83,11 +83,11 @@ CREATE TABLE permission.role_users
( (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE, RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE,
UserId INT NOT NULL REFERENCES administration.auth_users (id) ON DELETE CASCADE, UserId INT NOT NULL REFERENCES administration.users (id) ON DELETE CASCADE,
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),
CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId) CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId)

View File

@@ -6,7 +6,7 @@ CREATE TABLE permission.api_key_permissions
-- for history -- for history
deleted BOOLEAN NOT NULL DEFAULT FALSE, deleted BOOLEAN NOT NULL DEFAULT FALSE,
editorId INT NULL REFERENCES administration.auth_users (id), editorId INT NULL REFERENCES administration.users (id),
created timestamptz NOT NULL DEFAULT NOW(), created timestamptz NOT NULL DEFAULT NOW(),
updated timestamptz NOT NULL DEFAULT NOW(), updated timestamptz NOT NULL DEFAULT NOW(),
CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId) CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId)

View File

@@ -1,13 +1,13 @@
from contextvars import ContextVar from contextvars import ContextVar
from typing import Optional from typing import Optional
from cpl.auth.schema._administration.auth_user import AuthUser from cpl.auth.schema._administration.user import User
from cpl.dependency import get_provider from cpl.dependency import get_provider
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None) _user_context: ContextVar[Optional[User]] = ContextVar("user", default=None)
def set_user(user: Optional[AuthUser]): def set_user(user: Optional[User]):
from cpl.core.log.logger_abc import LoggerABC from cpl.core.log.logger_abc import LoggerABC
logger = get_provider().get_service(LoggerABC) logger = get_provider().get_service(LoggerABC)
@@ -15,5 +15,5 @@ def set_user(user: Optional[AuthUser]):
_user_context.set(user) _user_context.set(user)
def get_user() -> Optional[AuthUser]: def get_user() -> Optional[User]:
return _user_context.get() return _user_context.get()

View File

@@ -68,7 +68,7 @@ class StructuredLogger(Logger):
message["request"] = { message["request"] = {
"url": str(request.url), "url": str(request.url),
"method": request.method, "method": request.method if request.scope == "http" else "websocket",
"scope": self._scope_to_json(request), "scope": self._scope_to_json(request),
} }
if isinstance(request, Request) and request.scope == "http": if isinstance(request, Request) and request.scope == "http":

View File

@@ -2,10 +2,6 @@ import os
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cpl.core.log.logger import Logger
_logger = Logger(__name__)
class CredentialManager: class CredentialManager:
r"""Handles credential encryption and decryption""" r"""Handles credential encryption and decryption"""
@@ -14,6 +10,8 @@ class CredentialManager:
@classmethod @classmethod
def with_secret(cls, file: str = None): def with_secret(cls, file: str = None):
from cpl.core.log import Logger
if file is None: if file is None:
file = ".secret" file = ".secret"
@@ -25,12 +23,12 @@ class CredentialManager:
with open(file, "w") as secret_file: with open(file, "w") as secret_file:
secret_file.write(Fernet.generate_key().decode()) secret_file.write(Fernet.generate_key().decode())
secret_file.close() secret_file.close()
_logger.warning("Secret file not found, regenerating") Logger(__name__).warning("Secret file not found, regenerating")
with open(file, "r") as secret_file: with open(file, "r") as secret_file:
secret = secret_file.read().strip() secret = secret_file.read().strip()
if secret == "" or secret is None: if secret == "" or secret is None:
_logger.fatal("No secret found in .secret file.") Logger(__name__).fatal("No secret found in .secret file.")
cls._secret = str(secret) cls._secret = str(secret)

View File

@@ -46,6 +46,10 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
def table_name(self) -> str: def table_name(self) -> str:
return self._table_name return self._table_name
@property
def type(self) -> Type[T_DBM]:
return self._model_type
def has_attribute(self, attr_name: Attribute) -> bool: def has_attribute(self, attr_name: Attribute) -> bool:
""" """
Check if the attribute exists in the DAO Check if the attribute exists in the DAO
@@ -81,7 +85,7 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
self.__ignored_attributes.add(attr_name) self.__ignored_attributes.add(attr_name)
if not db_name: if not db_name:
db_name = attr_name.lower().replace("_", "") db_name = String.to_camel_case(attr_name)
self.__db_names[attr_name] = db_name self.__db_names[attr_name] = db_name
self.__db_names[db_name] = db_name self.__db_names[db_name] = db_name
@@ -490,16 +494,16 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
table, join_condition = self.__foreign_tables[attr] table, join_condition = self.__foreign_tables[attr]
builder.with_left_join(table, join_condition) builder.with_left_join(table, join_condition)
if filters: if filters is not None:
await self._build_conditions(builder, filters, external_table_deps) await self._build_conditions(builder, filters, external_table_deps)
if sorts: if sorts is not None:
self._build_sorts(builder, sorts, external_table_deps) self._build_sorts(builder, sorts, external_table_deps)
if take: if take is not None:
builder.with_limit(take) builder.with_limit(take)
if skip: if skip is not None:
builder.with_offset(skip) builder.with_offset(skip)
for external_table in external_table_deps: for external_table in external_table_deps:

View File

@@ -12,9 +12,9 @@ class DbJoinModelABC[T](DbModelABC[T]):
source_id: Id, source_id: Id,
foreign_id: Id, foreign_id: Id,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated) DbModelABC.__init__(self, id, deleted, editor_id, created, updated)

View File

@@ -2,7 +2,10 @@ from abc import ABC
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, Generic from typing import Optional, Generic
from async_property import async_property
from cpl.core.typing import Id, SerialId, T from cpl.core.typing import Id, SerialId, T
from cpl.dependency import get_provider
class DbModelABC(ABC, Generic[T]): class DbModelABC(ABC, Generic[T]):
@@ -10,9 +13,9 @@ class DbModelABC(ABC, Generic[T]):
self, self,
id: Id, id: Id,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: SerialId | None = None,
created: Optional[datetime] = None, created: datetime | None = None,
updated: Optional[datetime] = None, updated: datetime | None = None,
): ):
self._id = id self._id = id
self._deleted = deleted self._deleted = deleted
@@ -41,14 +44,16 @@ class DbModelABC(ABC, Generic[T]):
def editor_id(self, value: SerialId): def editor_id(self, value: SerialId):
self._editor_id = value self._editor_id = value
# @async_property @async_property
# async def editor(self): async def editor(self):
# if self._editor_id is None: if self._editor_id is None:
# return None return None
#
# from data.schemas.administration.user_dao import userDao from cpl.auth.schema import UserDao
#
# return await userDao.get_by_id(self._editor_id) user_dao = get_provider().get_service(UserDao)
return await user_dao.get_by_id(self._editor_id)
@property @property
def created(self) -> datetime: def created(self) -> datetime:

View File

@@ -18,7 +18,7 @@ class DbModelDaoABC[T_DBM](DataAccessObjectABC[T_DBM]):
self.attribute(DbModelABC.editor_id, int, db_name="editorId", ignore=True) # handled by db trigger self.attribute(DbModelABC.editor_id, int, db_name="editorId", ignore=True) # handled by db trigger
self.reference( self.reference(
"editor", "id", DbModelABC.editor_id, TableManager.get("auth_users") "editor", "id", DbModelABC.editor_id, TableManager.get("users")
) # not relevant for updates due to editor_id ) # not relevant for updates due to editor_id
self.attribute(DbModelABC.created, datetime, ignore=True) # handled by db trigger self.attribute(DbModelABC.created, datetime, ignore=True) # handled by db trigger

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from cpl.core.configuration import Configuration from cpl.core.configuration.configuration import Configuration
from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC from cpl.core.configuration.configuration_model_abc import ConfigurationModelABC

View File

@@ -1,6 +1,8 @@
from typing import Optional, Any from typing import Optional, Any
import sqlparse import sqlparse
import asyncio
from mysql.connector import errors, PoolError
from mysql.connector.aio import MySQLConnectionPool from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment from cpl.core.environment import Environment
@@ -10,7 +12,6 @@ from cpl.dependency.context import get_provider
class MySQLPool: class MySQLPool:
def __init__(self, database_settings: DatabaseSettings): def __init__(self, database_settings: DatabaseSettings):
self._dbconfig = { self._dbconfig = {
"host": database_settings.host, "host": database_settings.host,
@@ -25,59 +26,87 @@ class MySQLPool:
"ssl_disabled": database_settings.ssl_disabled, "ssl_disabled": database_settings.ssl_disabled,
} }
self._pool: Optional[MySQLConnectionPool] = None self._pool: Optional[MySQLConnectionPool] = None
self._pool_lock = asyncio.Lock()
async def _get_pool(self): async def _get_pool(self) -> MySQLConnectionPool:
if self._pool is None: if self._pool is None:
try: async with self._pool_lock:
self._pool = MySQLConnectionPool( if self._pool is None:
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig try:
) self._pool = MySQLConnectionPool(
await self._pool.initialize_pool() pool_name="cplpool",
pool_size=Environment.get("DB_POOL_SIZE", int, 20),
**self._dbconfig,
)
await self._pool.initialize_pool()
con = await self._pool.get_connection() # Testverbindung (Ping)
async with await con.cursor() as cursor: con = await self._pool.get_connection()
await cursor.execute("SELECT 1") try:
await cursor.fetchall() async with await con.cursor() as cursor:
await cursor.execute("SELECT 1")
await con.close() await cursor.fetchall()
except Exception as e: finally:
logger = get_provider().get_service(DBLogger) await con.close()
logger.fatal(f"Error connecting to the database", e)
except Exception as e:
logger = get_provider().get_service(DBLogger)
logger.fatal("Error connecting to the database", e)
raise
return self._pool return self._pool
async def _get_connection(self, retries: int = 3, delay: float = 0.5):
"""Stabiler Connection-Getter mit Retry und Ping"""
pool = await self._get_pool()
for attempt in range(retries):
try:
con = await pool.get_connection()
# Verbindungs-Check (Ping)
try:
async with await con.cursor() as cursor:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except errors.OperationalError:
await con.close()
raise
return con
except PoolError:
if attempt == retries - 1:
raise
await asyncio.sleep(delay)
@staticmethod @staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True): async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
result = [] result = []
if multi: if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()] queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries: for q in queries:
if q.strip() == "": if q:
continue await cursor.execute(q, args)
await cursor.execute(q, args) if cursor.description is not None:
if cursor.description is not None: result = await cursor.fetchall()
result = await cursor.fetchall()
else: else:
await cursor.execute(query, args) await cursor.execute(query, args)
if cursor.description is not None: if cursor.description is not None:
result = await cursor.fetchall() result = await cursor.fetchall()
return result return result
async def execute(self, query: str, args=None, multi=True) -> list[list]: async def execute(self, query: str, args=None, multi=True) -> list[str]:
pool = await self._get_pool() con = await self._get_connection()
con = await pool.get_connection()
try: try:
async with await con.cursor() as cursor: async with await con.cursor() as cursor:
result = await self._exec_sql(cursor, query, args, multi) res = await self._exec_sql(cursor, query, args, multi)
await con.commit() await con.commit()
return result return list(res)
finally: finally:
await con.close() await con.close()
async def select(self, query: str, args=None, multi=True) -> list[str]: async def select(self, query: str, args=None, multi=True) -> list[str]:
pool = await self._get_pool() con = await self._get_connection()
con = await pool.get_connection()
try: try:
async with await con.cursor() as cursor: async with await con.cursor() as cursor:
res = await self._exec_sql(cursor, query, args, multi) res = await self._exec_sql(cursor, query, args, multi)
@@ -86,11 +115,17 @@ class MySQLPool:
await con.close() await con.close()
async def select_map(self, query: str, args=None, multi=True) -> list[dict]: async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
pool = await self._get_pool() con = await self._get_connection()
con = await pool.get_connection()
try: try:
async with await con.cursor(dictionary=True) as cursor: async with await con.cursor(dictionary=True) as cursor:
res = await self._exec_sql(cursor, query, args, multi) res = await self._exec_sql(cursor, query, args, multi)
return list(res) decoded_res = []
for row in res:
decoded_row = {
k: (v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v) for k, v in row.items()
}
decoded_res.append(decoded_row)
return decoded_res
finally: finally:
await con.close() await con.close()

View File

@@ -27,7 +27,7 @@ class PostgresPool:
self._pool: Optional[AsyncConnectionPool] = None self._pool: Optional[AsyncConnectionPool] = None
async def _get_pool(self): async def _get_pool(self):
if self._pool is None: if self._pool is None or self._pool.closed:
pool = AsyncConnectionPool( pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
) )

View File

@@ -1,15 +1,15 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Self
from cpl.database.abc import DbModelABC from cpl.database.abc import DbModelABC
class ExecutedMigration(DbModelABC): class ExecutedMigration(DbModelABC[Self]):
def __init__( def __init__(
self, self,
migration_id: str, migration_id: str,
created: Optional[datetime] = None, created: datetime | None = None,
modified: Optional[datetime] = None, modified: datetime | None = None,
): ):
DbModelABC.__init__(self, migration_id, False, created, modified) DbModelABC.__init__(self, migration_id, False, created, modified)

View File

@@ -7,9 +7,9 @@ class TableManager:
ServerTypes.POSTGRES: "system._executed_migrations", ServerTypes.POSTGRES: "system._executed_migrations",
ServerTypes.MYSQL: "system__executed_migrations", ServerTypes.MYSQL: "system__executed_migrations",
}, },
"auth_users": { "users": {
ServerTypes.POSTGRES: "administration.auth_users", ServerTypes.POSTGRES: "administration.users",
ServerTypes.MYSQL: "administration_auth_users", ServerTypes.MYSQL: "administration_users",
}, },
"api_keys": { "api_keys": {
ServerTypes.POSTGRES: "administration.api_keys", ServerTypes.POSTGRES: "administration.api_keys",
@@ -33,7 +33,7 @@ class TableManager:
}, },
"role_users": { "role_users": {
ServerTypes.POSTGRES: "permission.role_users", ServerTypes.POSTGRES: "permission.role_users",
ServerTypes.MYSQL: "permission_role_auth_users", ServerTypes.MYSQL: "permission_role_users",
}, },
} }

View File

@@ -0,0 +1,10 @@
from abc import abstractmethod, ABC
from typing import Any, AsyncGenerator
class EventBusABC(ABC):
@abstractmethod
async def publish(self, channel: str, event: Any) -> None: ...
@abstractmethod
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]: ...

View File

@@ -8,7 +8,7 @@ class ModuleABC(ABC):
__OPTIONAL_VARS = ["dependencies", "configuration", "singleton", "scoped", "transient", "hosted"] __OPTIONAL_VARS = ["dependencies", "configuration", "singleton", "scoped", "transient", "hosted"]
def __init_subclass__(cls): def __init_subclass__(cls):
super().__init_subclass__() ABC.__init_subclass__()
if f"{cls.__module__}.{cls.__name__}" == "cpl.dependency.module.module.Module": if f"{cls.__module__}.{cls.__name__}" == "cpl.dependency.module.module.Module":
return return

View File

@@ -25,7 +25,7 @@ class ServiceProvider:
for descriptor in self._service_descriptors: for descriptor in self._service_descriptors:
if typing.get_origin(service_type) is None and ( if typing.get_origin(service_type) is None and (
descriptor.service_type == service_type descriptor.service_type.__name__ == service_type.__name__
or typing.get_origin(descriptor.base_type) is None or typing.get_origin(descriptor.base_type) is None
and issubclass(descriptor.base_type, service_type) and issubclass(descriptor.base_type, service_type)
): ):

View File

View File

@@ -0,0 +1,69 @@
from starlette.responses import HTMLResponse
async def graphiql_endpoint(request):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>GraphiQL</title>
<link
href="https://unpkg.com/graphiql@2.4.0/graphiql.min.css"
rel="stylesheet"
/>
</head>
<body style="margin:0;overflow:hidden;">
<div id="graphiql" style="height:100vh;"></div>
<!-- React + ReactDOM -->
<script src="https://unpkg.com/react@18.2.0/umd/react.production.min.js"></script>
<script src="https://unpkg.com/react-dom@18.2.0/umd/react-dom.production.min.js"></script>
<!-- GraphiQL -->
<script src="https://unpkg.com/graphiql@2.4.0/graphiql.min.js"></script>
<!-- GraphQL over WebSocket client -->
<script src="https://unpkg.com/graphql-ws@5.11.3/umd/graphql-ws.min.js"></script>
<script>
const httpUrl = window.location.origin + '/api/graphql';
const wsUrl = (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
window.location.host + '/api/graphql/ws';
// HTTP fetcher for queries & mutations
const httpFetcher = async (params) => {
const res = await fetch(httpUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(params),
});
return res.json();
};
// WebSocket fetcher for subscriptions
const wsClient = graphqlWs.createClient({ url: wsUrl });
const wsFetcher = (params) => ({
subscribe: (sink) => ({
unsubscribe: wsClient.subscribe(params, sink),
}),
});
// smart fetcher wrapper (decides HTTP or WS)
const graphQLFetcher = (params) => {
if (params.query.trim().startsWith('subscription')) {
return wsFetcher(params);
}
return httpFetcher(params);
};
ReactDOM.render(
React.createElement(GraphiQL, { fetcher: graphQLFetcher }),
document.getElementById('graphiql'),
);
</script>
</body>
</html>
"""
)

View File

@@ -0,0 +1,13 @@
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from cpl.graphql.service.graphql import GraphQLService
async def graphql_endpoint(request: Request, service: GraphQLService) -> Response:
body = await request.json()
query = body.get("query")
variables = body.get("variables")
response_data = await service.execute(query, variables, request)
return JSONResponse(response_data)

View File

@@ -0,0 +1,27 @@
from starlette.requests import Request
from starlette.responses import Response
from strawberry.asgi import GraphQL
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from cpl.dependency import ServiceProvider
from cpl.graphql.service.schema import Schema
class LazyGraphQLApp:
def __init__(self, services: ServiceProvider):
self._services = services
self._graphql_app = None
async def __call__(self, scope, receive, send):
if self._graphql_app is None:
schema = self._services.get_service(Schema)
if not schema or not schema.schema:
raise RuntimeError("GraphQL Schema not available yet")
self._graphql_app = GraphQL(
schema.schema,
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
)
await self._graphql_app(scope, receive, send)

View File

@@ -0,0 +1,29 @@
from starlette.requests import Request
from starlette.responses import Response, HTMLResponse
async def playground_endpoint(request: Request) -> Response:
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<meta charset=utf-8/>
<title>GraphQL Playground</title>
<link rel="stylesheet" href="https://unpkg.com/graphql-playground-react/build/static/css/index.css" />
<link rel="shortcut icon" href="https://raw.githubusercontent.com/graphql/graphql-playground/master/packages/graphql-playground-react/public/favicon.png" />
<script src="https://unpkg.com/graphql-playground-react/build/static/js/middleware.js"></script>
</head>
<body>
<div id="root"/>
<script>
window.addEventListener('load', function () {
GraphQLPlayground.init(document.getElementById('root'), {
endpoint: '/api/graphql'
})
})
</script>
</body>
</html>
"""
)

View File

@@ -0,0 +1,227 @@
import functools
import inspect
import types
from abc import ABC
from asyncio import iscoroutinefunction
from typing import Callable, Type, Any, Optional
import strawberry
from async_property.base import AsyncPropertyDescriptor
from strawberry.exceptions import StrawberryException
from cpl.api import Unauthorized, Forbidden
from cpl.core.ctx.user_context import get_user
from cpl.dependency import get_provider
from cpl.graphql.abc.strawberry_protocol import StrawberryProtocol
from cpl.graphql.error import graphql_error
from cpl.graphql.query_context import QueryContext
from cpl.graphql.schema.field import Field
from cpl.graphql.typing import Resolver, AttributeName
from cpl.graphql.utils.type_collector import TypeCollector
class QueryABC(StrawberryProtocol, ABC):
def __init__(self):
ABC.__init__(self)
self._fields: dict[str, Field] = {}
@property
def fields(self) -> dict[str, Field]:
return self._fields
@property
def fields_count(self) -> int:
return len(self._fields)
def get_fields(self) -> dict[str, Field]:
return self._fields
def field(
self,
name: AttributeName,
t: type,
resolver: Resolver = None,
) -> Field:
from cpl.graphql.schema.field import Field
if isinstance(name, property):
name = name.fget.__name__
self._fields[name] = Field(name, t, resolver)
return self._fields[name]
def string_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, str, resolver)
def int_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, int, resolver)
def float_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, float, resolver)
def bool_field(self, name: AttributeName, resolver: Resolver = None) -> Field:
return self.field(name, bool, resolver)
def list_field(self, name: AttributeName, t: type, resolver: Resolver = None) -> Field:
return self.field(name, list[t], resolver)
def object_field(self, name: str, t: Type[StrawberryProtocol], resolver: Resolver = None) -> Field:
if not isinstance(t, type) and callable(t):
return self.field(name, t, resolver)
return self.field(name, t().to_strawberry(), resolver)
@staticmethod
def _build_resolver(f: "Field"):
params: list[inspect.Parameter] = []
for arg in f.arguments.values():
_type = arg.type
if isinstance(_type, type) and issubclass(_type, StrawberryProtocol):
_type = _type().to_strawberry()
ann = Optional[_type] if arg.optional else _type
if arg.default is None:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
)
else:
param = inspect.Parameter(
arg.name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=ann,
default=arg.default,
)
params.append(param)
sig = inspect.Signature(parameters=params, return_annotation=f.type)
async def _resolver(*args, **kwargs):
if f.resolver is None:
return None
if iscoroutinefunction(f.resolver):
return await f.resolver(*args, **kwargs)
return f.resolver(*args, **kwargs)
_resolver.__signature__ = sig
return _resolver
def _wrap_with_auth(self, f: Field, resolver: Callable) -> Callable:
sig = getattr(resolver, "__signature__", None)
@functools.wraps(resolver)
async def _auth_resolver(*args, **kwargs):
if f.public:
return await self._run_resolver(resolver, *args, **kwargs)
user = get_user()
if user is None:
raise graphql_error(Unauthorized(f"{f.name}: Authentication required"))
if f.require_any_permission:
if not any([await user.has_permission(p) for p in f.require_any_permission]):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
if f.require_any:
perms, resolvers = f.require_any
if not any([await user.has_permission(p) for p in perms]):
ctx = QueryContext([x.name for x in await user.permissions])
resolved = [r(ctx) if not iscoroutinefunction(r) else await r(ctx) for r in resolvers]
if not any(resolved):
raise graphql_error(Forbidden(f"{f.name}: Permission denied"))
return await self._run_resolver(resolver, *args, **kwargs)
if sig:
_auth_resolver.__signature__ = sig
return _auth_resolver
@staticmethod
async def _run_resolver(r: Callable, *args, **kwargs):
result = r(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result
def _field_to_strawberry(self, f: Field) -> Any:
resolver = None
try:
if f.arguments:
resolver = self._build_resolver(f)
elif not f.resolver:
resolver = lambda root: None
else:
ann = getattr(f.resolver, "__annotations__", {})
if "return" not in ann or ann["return"] is None:
ann = dict(ann)
ann["return"] = f.type
f.resolver.__annotations__ = ann
resolver = f.resolver
return strawberry.field(resolver=self._wrap_with_auth(f, resolver))
except StrawberryException as e:
raise Exception(f"Error converting field '{f.name}' to strawberry field: {e}") from e
@staticmethod
def _type_to_strawberry(t: Type) -> Type:
_t = get_provider().get_service(t)
if isinstance(_t, StrawberryProtocol):
return _t.to_strawberry()
return t
def to_strawberry(self) -> Type:
cls = self.__class__
if TypeCollector.has(cls):
return TypeCollector.get(cls)
gql_cls = type(f"{cls.__name__.replace('GraphType', '')}", (), {})
# register early to handle recursive types
TypeCollector.set(cls, gql_cls)
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {}
for name, f in self._fields.items():
t = f.type
if isinstance(name, property):
name = name.fget.__name__
if isinstance(name, AsyncPropertyDescriptor):
name = name.field_name
if isinstance(t, types.GenericAlias):
t = t.__args__[0]
if callable(t) and not isinstance(t, type):
t = self._type_to_strawberry(t())
elif issubclass(t, StrawberryProtocol):
t = self._type_to_strawberry(t)
annotations[name] = t if not f.optional else Optional[t]
namespace[name] = self._field_to_strawberry(f)
namespace["__annotations__"] = annotations
for k, v in namespace.items():
if isinstance(k, property):
k = k.fget.__name__
if isinstance(k, AsyncPropertyDescriptor):
k = k.field_name
setattr(gql_cls, k, v)
try:
gql_cls.__annotations__ = annotations
gql_type = strawberry.type(gql_cls)
except Exception as e:
raise Exception(f"Error creating strawberry type for '{cls.__name__}': {e}") from e
TypeCollector.set(cls, gql_type)
return gql_type

View File

@@ -0,0 +1,11 @@
from typing import Protocol, Type, runtime_checkable
from cpl.graphql.schema.field import Field
from cpl.graphql.schema.subscription_field import SubscriptionField
@runtime_checkable
class StrawberryProtocol(Protocol):
def to_strawberry(self) -> Type: ...
def get_fields(self) -> dict[str, Field | SubscriptionField]: ...

View File

@@ -0,0 +1 @@
from .graphql_app import WebApp

View File

@@ -0,0 +1,126 @@
import socket
from enum import Enum
from typing import Self
from cpl.api.application import WebApp
from cpl.api.model.validation_match import ValidationMatch
from cpl.application.abc.application_abc import __not_implemented__
from cpl.core.environment import Environment
from cpl.dependency.service_provider import ServiceProvider
from cpl.dependency.typing import Modules
from cpl.graphql._endpoints.graphiql import graphiql_endpoint
from cpl.graphql._endpoints.graphql import graphql_endpoint
from cpl.graphql._endpoints.lazy_graphql_app import LazyGraphQLApp
from cpl.graphql._endpoints.playground import playground_endpoint
from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.service.schema import Schema
class GraphQLApp(WebApp):
def __init__(self, services: ServiceProvider, modules: Modules):
WebApp.__init__(self, services, modules, [GraphQLModule])
self._with_graphiql = False
self._with_playground = False
def with_graphql(
self,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Schema:
self.with_route(
path="/api/graphql",
fn=graphql_endpoint,
method="POST",
authentication=authentication,
roles=roles,
permissions=permissions,
policies=policies,
match=match,
)
schema = self._services.get_service(Schema)
if schema is None:
self._logger.fatal("Could not resolve RootQuery. Make sure GraphQLModule is registered.")
#
# graphql_ws_app = GraphQL(
# schema,
# subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
# )
self.with_websocket("/api/graphql/ws", LazyGraphQLApp(self._services))
return schema
def with_graphiql(
self,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self.with_route(
path="/api/graphiql",
fn=graphiql_endpoint,
method="GET",
authentication=authentication,
roles=roles,
permissions=permissions,
policies=policies,
match=match,
)
self._with_graphiql = True
return self
def with_playground(
self,
authentication: bool = False,
roles: list[str | Enum] = None,
permissions: list[str | Enum] = None,
policies: list[str] = None,
match: ValidationMatch = None,
) -> Self:
self.with_route(
path="/api/playground",
fn=playground_endpoint,
method="GET",
authentication=authentication,
roles=roles,
permissions=permissions,
policies=policies,
match=match,
)
self._with_playground = True
return self
def with_auth_root_queries(self, public: bool = False):
try:
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
GraphQLAuthModule.with_auth_root_queries(self._services, public=public)
except ImportError:
__not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations)
def with_auth_root_mutations(self, public: bool = False):
try:
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
GraphQLAuthModule.with_auth_root_mutations(self._services, public=public)
except ImportError:
__not_implemented__("cpl-auth & cpl-graphql", self.with_auth_root_mutations)
async def _log_before_startup(self):
host = self._api_settings.host
if host == "0.0.0.0" and Environment.get_environment() == "development":
host = "localhost"
elif host == "0.0.0.0":
host = socket.gethostbyname(socket.gethostname())
self._logger.info(f"Start API on {host}:{self._api_settings.port}")
if self._with_graphiql:
self._logger.warning(f"GraphiQL available at http://{host}:{self._api_settings.port}/api/graphiql")
if self._with_playground:
self._logger.warning(
f"GraphQL Playground available at http://{host}:{self._api_settings.port}/api/playground"
)

View File

@@ -0,0 +1,10 @@
from cpl.auth.schema import ApiKey
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class ApiKeyFilter(DbModelFilter[ApiKey]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("identifier", StringFilter).with_public(public)

View File

@@ -0,0 +1,14 @@
from cpl.auth.schema import ApiKey, RolePermissionDao
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class ApiKeyGraphType(DbModelGraphType[ApiKey]):
def __init__(self, role_permission_dao: RolePermissionDao):
DbModelGraphType.__init__(self)
self.string_field(ApiKey.identifier, lambda root: root.identifier)
self.string_field(ApiKey.key, lambda root: root.key)
self.string_field(ApiKey.permissions, lambda root: root.permissions)
self.set_history_reference_dao(role_permission_dao, "apikeyid")

View File

@@ -0,0 +1,25 @@
from cpl.auth.schema import ApiKey
from cpl.core.typing import SerialId
from cpl.graphql.schema.input import Input
class ApiKeyCreateInput(Input[ApiKey]):
identifier: str
permissions: list[SerialId]
def __init__(self):
Input.__init__(self)
self.string_field("identifier").with_required()
self.list_field("permissions", SerialId)
class ApiKeyUpdateInput(Input[ApiKey]):
id: SerialId
identifier: str | None
permissions: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.string_field("identifier").with_required()
self.list_field("permissions", SerialId)

View File

@@ -0,0 +1,93 @@
from cpl.api import APILogger
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.permission import Permissions
from cpl.auth.schema import ApiKey, ApiKeyDao, ApiKeyPermissionDao, ApiKeyPermission
from cpl.graphql.auth.api_key.api_key_input import ApiKeyUpdateInput, ApiKeyCreateInput
from cpl.graphql.schema.mutation import Mutation
class ApiKeyMutation(Mutation):
def __init__(
self,
logger: APILogger,
api_key_dao: ApiKeyDao,
api_key_permission_dao: ApiKeyPermissionDao,
permission_dao: ApiKeyPermissionDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
self._api_key_dao = api_key_dao
self._api_key_permission_dao = api_key_permission_dao
self._permission_dao = permission_dao
self._keycloak_admin = keycloak_admin
self.int_field(
"create",
self.resolve_create,
).with_require_any_permission(Permissions.api_keys_create).with_argument(
"input",
ApiKeyCreateInput,
).with_required()
self.bool_field(
"update",
self.resolve_update,
).with_require_any_permission(Permissions.api_keys_update).with_argument(
"input",
ApiKeyUpdateInput,
).with_required()
self.bool_field(
"delete",
self.resolve_delete,
).with_require_any_permission(Permissions.api_keys_delete).with_argument(
"id",
int,
).with_required()
self.bool_field(
"restore",
self.resolve_restore,
).with_require_any_permission(Permissions.api_keys_delete).with_argument(
"id",
int,
).with_required()
async def resolve_create(self, obj: ApiKeyCreateInput):
self._logger.debug(f"create api key: {obj.__dict__}")
api_key = ApiKey.new(obj.identifier)
await self._api_key_dao.create(api_key)
api_key = await self._api_key_dao.get_single_by([{ApiKey.identifier: obj.identifier}])
await self._api_key_permission_dao.create_many([ApiKeyPermission(0, api_key.id, x) for x in obj.permissions])
return api_key
async def resolve_update(self, input: ApiKeyUpdateInput):
self._logger.debug(f"update api key: {input}")
api_key = await self._api_key_dao.get_by_id(input.id)
await self._resolve_assignments(
input.permissions or [],
api_key,
ApiKeyPermission.api_key_id,
ApiKeyPermission.permission_id,
self._api_key_dao,
self._api_key_permission_dao,
ApiKeyPermission,
self._permission_dao,
)
return api_key
async def resolve_delete(self, id: str):
self._logger.debug(f"delete api key: {id}")
api_key = await self._api_key_dao.get_by_id(id)
await self._api_key_dao.delete(api_key)
return True
async def resolve_restore(self, id: str):
self._logger.debug(f"restore api key: {id}")
api_key = await self._api_key_dao.get_by_id(id)
await self._api_key_dao.restore(api_key)
return True

View File

@@ -0,0 +1,9 @@
from cpl.auth.schema import ApiKey
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
from cpl.graphql.schema.sort.sort_order import SortOrder
class ApiKeySort(DbModelSort[ApiKey]):
def __init__(self):
DbModelSort.__init__(self)
self.field("identifier", SortOrder)

View File

@@ -0,0 +1,77 @@
from cpl.auth.permission import Permissions
from cpl.auth.schema import UserDao, ApiKeyDao, RoleDao
from cpl.core.configuration import Configuration
from cpl.dependency import ServiceProvider
from cpl.dependency.module.module import Module
from cpl.dependency.service_collection import ServiceCollection
from cpl.graphql.auth.api_key.api_key_filter import ApiKeyFilter
from cpl.graphql.auth.api_key.api_key_graph_type import ApiKeyGraphType
from cpl.graphql.auth.api_key.api_key_mutation import ApiKeyMutation
from cpl.graphql.auth.api_key.api_key_sort import ApiKeySort
from cpl.graphql.auth.role.role_filter import RoleFilter
from cpl.graphql.auth.role.role_graph_type import RoleGraphType
from cpl.graphql.auth.role.role_mutation import RoleMutation
from cpl.graphql.auth.role.role_sort import RoleSort
from cpl.graphql.auth.user.user_filter import UserFilter
from cpl.graphql.auth.user.user_graph_type import UserGraphType
from cpl.graphql.auth.user.user_mutation import UserMutation
from cpl.graphql.auth.user.user_sort import UserSort
from cpl.graphql.graphql_module import GraphQLModule
from cpl.graphql.service.schema import Schema
class GraphQLAuthModule(Module):
dependencies = [GraphQLModule]
transient = [
UserGraphType,
UserMutation,
UserFilter,
UserSort,
ApiKeyGraphType,
ApiKeyMutation,
ApiKeyFilter,
ApiKeySort,
RoleGraphType,
RoleMutation,
RoleFilter,
RoleSort,
]
@staticmethod
def register(collection: ServiceCollection):
Configuration.set("GraphQLAuthModuleEnabled", True)
@staticmethod
def configure(provider: ServiceProvider):
schema = provider.get_service(Schema)
schema.with_type(UserGraphType)
schema.with_type(ApiKeyGraphType)
schema.with_type(RoleGraphType)
@staticmethod
def with_auth_root_queries(provider: ServiceProvider, public: bool = False):
if not Configuration.get("GraphQLAuthModuleEnabled", False):
raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'")
schema = provider.get_service(Schema)
schema.query.dao_collection_field(
UserGraphType, UserDao, "users", UserFilter, UserSort
).with_require_any_permission(Permissions.users).with_public(public)
schema.query.dao_collection_field(
ApiKeyGraphType, ApiKeyDao, "apiKeys", ApiKeyFilter, ApiKeySort
).with_require_any_permission(Permissions.api_keys).with_public(public)
schema.query.dao_collection_field(
RoleGraphType, RoleDao, "roles", RoleFilter, RoleSort
).with_require_any_permission(Permissions.roles).with_public(public)
@staticmethod
def with_auth_root_mutations(provider: ServiceProvider, public: bool = False):
if not Configuration.get("GraphQLAuthModuleEnabled", False):
raise Exception("GraphQLAuthModule is not loaded yet. Make sure to run 'add_module(GraphQLAuthModule)'")
schema = provider.get_service(Schema)
schema.mutation.with_mutation("user", UserMutation).with_public(public)
schema.mutation.with_mutation("apiKey", ApiKeyMutation).with_public(public)
schema.mutation.with_mutation("role", RoleMutation).with_public(public)

View File

@@ -0,0 +1,11 @@
from cpl.auth.schema import User, Role
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class RoleFilter(DbModelFilter[Role]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("name", StringFilter).with_public(public)
self.field("description", StringFilter).with_public(public)

View File

@@ -0,0 +1,14 @@
from cpl.auth.schema import Role
from cpl.graphql.auth.user.user_graph_type import UserGraphType
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class RoleGraphType(DbModelGraphType[Role]):
def __init__(self, public: bool = False):
DbModelGraphType.__init__(self)
self.string_field("name", lambda root: root.name).with_public(public)
self.string_field("description", lambda root: root.description).with_public(public)
self.list_field("permissions", str, lambda root: root.permissions).with_public(public)
self.list_field("users", UserGraphType, lambda root: root.users).with_public(public)

View File

@@ -0,0 +1,29 @@
from cpl.auth.schema import User, Role
from cpl.core.typing import SerialId
from cpl.graphql.schema.input import Input
class RoleCreateInput(Input[Role]):
name: str
description: str | None
permissions: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.string_field("name").with_required()
self.string_field("description")
self.list_field("permissions", SerialId)
class RoleUpdateInput(Input[Role]):
id: SerialId
name: str | None
description: str | None
permissions: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.string_field("name")
self.string_field("description")
self.list_field("permissions", SerialId)

View File

@@ -0,0 +1,101 @@
from cpl.api import APILogger
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.permission import Permissions
from cpl.auth.schema import RoleDao, Role, RolePermissionDao, RolePermission
from cpl.graphql.auth.role.role_input import RoleCreateInput, RoleUpdateInput
from cpl.graphql.schema.mutation import Mutation
class RoleMutation(Mutation):
def __init__(
self,
logger: APILogger,
role_dao: RoleDao,
role_permission_dao: RolePermissionDao,
permission_dao: RolePermissionDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
self._role_dao = role_dao
self._role_permission_dao = role_permission_dao
self._permission_dao = permission_dao
self._keycloak_admin = keycloak_admin
self.int_field(
"create",
self.resolve_create,
).with_require_any_permission(Permissions.roles_create).with_argument(
"input",
RoleCreateInput,
).with_required()
self.bool_field(
"update",
self.resolve_update,
).with_require_any_permission(Permissions.roles_update).with_argument(
"input",
RoleUpdateInput,
).with_required()
self.bool_field(
"delete",
self.resolve_delete,
).with_require_any_permission(Permissions.roles_delete).with_argument(
"id",
int,
).with_required()
self.bool_field(
"restore",
self.resolve_restore,
).with_require_any_permission(Permissions.roles_delete).with_argument(
"id",
int,
).with_required()
async def resolve_create(self, input: RoleCreateInput, *_):
self._logger.debug(f"create role: {input.__dict__}")
role = Role(
0,
input.name,
input.description,
)
await self._role_dao.create(role)
role = await self._role_dao.get_by_name(role.name)
await self._role_permission_dao.create_many([RolePermission(0, role.id, x) for x in input.permissions])
return role
async def resolve_update(self, input: RoleUpdateInput, *_):
self._logger.debug(f"update role: {input.__dict__}")
role = await self._role_dao.get_by_id(input.id)
role.name = input.get("name", role.name)
role.description = input.get("description", role.description)
await self._role_dao.update(role)
await self._resolve_assignments(
input.get("permissions", []),
role,
RolePermission.role_id,
RolePermission.permission_id,
self._role_dao,
self._role_permission_dao,
RolePermission,
self._permission_dao,
)
return role
async def resolve_delete(self, id: int):
self._logger.debug(f"delete role: {id}")
role = await self._role_dao.get_by_id(id)
await self._role_dao.delete(role)
return True
async def resolve_restore(self, id: int):
self._logger.debug(f"restore role: {id}")
role = await self._role_dao.get_by_id(id)
await self._role_dao.restore(role)
return True

View File

@@ -0,0 +1,10 @@
from cpl.auth.schema import Role
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
from cpl.graphql.schema.sort.sort_order import SortOrder
class RoleSort(DbModelSort[Role]):
def __init__(self):
DbModelSort.__init__(self)
self.field("name", SortOrder)
self.field("description", SortOrder)

View File

@@ -0,0 +1,11 @@
from cpl.auth.schema import User
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
class UserFilter(DbModelFilter[User]):
def __init__(self, public: bool = False):
DbModelFilter.__init__(self, public)
self.field("username", StringFilter).with_public(public)
self.field("email", StringFilter).with_public(public)

View File

@@ -0,0 +1,12 @@
from cpl.auth.schema import User
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
class UserGraphType(DbModelGraphType[User]):
def __init__(self, public: bool = False):
DbModelGraphType.__init__(self)
self.string_field(User.keycloak_id, lambda root: root.keycloak_id).with_public(public)
self.string_field(User.username, lambda root: root.username).with_public(public)
self.string_field(User.email, lambda root: root.email).with_public(public)

View File

@@ -0,0 +1,23 @@
from cpl.auth.schema import User
from cpl.core.typing import SerialId
from cpl.graphql.schema.input import Input
class UserCreateInput(Input[User]):
keycloak_id: str
roles: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.string_field("keycloak_id").with_required()
self.list_field("roles", SerialId)
class UserUpdateInput(Input[User]):
id: SerialId
roles: list[SerialId] | None
def __init__(self):
Input.__init__(self)
self.int_field("id").with_required()
self.list_field("roles", SerialId)

View File

@@ -0,0 +1,112 @@
from cpl.api import APILogger
from cpl.auth.keycloak import KeycloakAdmin
from cpl.auth.permission import Permissions
from cpl.auth.schema import UserDao, User, RoleUser, RoleUserDao, RoleDao
from cpl.core.ctx.user_context import get_user
from cpl.graphql.auth.user.user_input import UserCreateInput, UserUpdateInput
from cpl.graphql.schema.mutation import Mutation
class UserMutation(Mutation):
def __init__(
self,
logger: APILogger,
user_dao: UserDao,
role_user_dao: RoleUserDao,
role_dao: RoleDao,
keycloak_admin: KeycloakAdmin,
):
Mutation.__init__(self)
self._logger = logger
self._user_dao = user_dao
self._role_user_dao = role_user_dao
self._role_dao = role_dao
self._keycloak_admin = keycloak_admin
self.int_field(
"create",
self.resolve_create,
).with_require_any_permission(Permissions.users_create).with_argument(
"input",
UserCreateInput,
).with_required()
self.bool_field(
"update",
self.resolve_update,
).with_require_any_permission(Permissions.users_update).with_argument(
"input",
UserUpdateInput,
).with_required()
self.bool_field(
"delete",
self.resolve_delete,
).with_require_any_permission(Permissions.users_delete).with_argument(
"id",
int,
).with_required()
self.bool_field(
"restore",
self.resolve_restore,
).with_require_any_permission(Permissions.users_delete).with_argument(
"id",
int,
).with_required()
async def resolve_create(self, input: UserCreateInput):
self._logger.debug(f"create user: {input.__dict__}")
# ensure keycloak knows a user with this keycloak_id
# get_user should raise an exception if the user does not exist
kc_user = self._keycloak_admin.get_user(input.keycloak_id)
if kc_user is None:
raise ValueError(f"Keycloak user with id {input.keycloak_id} does not exist")
user = User(0, input.keycloak_id, input.license)
user_id = await self._user_dao.create(user)
user = await self._user_dao.get_by_id(user_id)
await self._role_user_dao.create_many([RoleUser(0, user.id, x) for x in set(input.roles)])
return user
async def resolve_update(self, input: UserUpdateInput):
self._logger.debug(f"update user: {input.__dict__}")
user = await self._user_dao.get_by_id(input.id)
if input.license:
user.license = input.license
await self._user_dao.update(user)
await self._resolve_assignments(
input.roles or [],
user,
RoleUser.user_id,
RoleUser.role_id,
self._user_dao,
self._role_user_dao,
RoleUser,
self._role_dao,
)
return user
async def resolve_delete(self, id: int):
self._logger.debug(f"delete user: {id}")
user = await self._user_dao.get_by_id(id)
await self._user_dao.delete(user)
try:
active_user = get_user()
if active_user is not None and active_user.id == user.id:
# await broadcast.publish("userLogout", user.id)
self._keycloak_admin.user_logout(user_id=user.keycloak_id)
except Exception as e:
self._logger.error(f"Failed to logout user from Keycloak", e)
return True
async def resolve_restore(self, id: int):
self._logger.debug(f"restore user: {id}")
user = await self._user_dao.get_by_id(id)
await self._user_dao.restore(user)
return True

View File

@@ -0,0 +1,10 @@
from cpl.auth.schema import User
from cpl.graphql.schema.sort.db_model_sort import DbModelSort
from cpl.graphql.schema.sort.sort_order import SortOrder
class UserSort(DbModelSort[User]):
def __init__(self):
DbModelSort.__init__(self)
self.field("username", SortOrder)
self.field("email", SortOrder)

View File

@@ -0,0 +1,14 @@
from graphql import GraphQLError
from cpl.api import APIError
def graphql_error(api_error: APIError) -> GraphQLError:
"""Convert an APIError (from cpl-api) into a GraphQL-friendly error."""
return GraphQLError(
message=api_error.error_message,
extensions={
"code": api_error.status_code,
},
original_error=api_error,
)

View File

@@ -0,0 +1,27 @@
import asyncio
from typing import Any, AsyncGenerator
from cpl.dependency.event_bus import EventBusABC
class InMemoryEventBus(EventBusABC):
def __init__(self):
self._subscribers: dict[str, list[asyncio.Queue]] = {}
async def publish(self, channel: str, event: Any) -> None:
queues = self._subscribers.get(channel, [])
for q in queues.copy():
await q.put(event)
async def subscribe(self, channel: str) -> AsyncGenerator[Any, None]:
q = asyncio.Queue()
if channel not in self._subscribers:
self._subscribers[channel] = []
self._subscribers[channel].append(q)
try:
while True:
item = await q.get()
yield item
finally:
self._subscribers[channel].remove(q)

View File

@@ -0,0 +1,25 @@
from cpl.api.api_module import ApiModule
from cpl.dependency.module.module import Module
from cpl.dependency.service_provider import ServiceProvider
from cpl.graphql.schema.filter.bool_filter import BoolFilter
from cpl.graphql.schema.filter.date_filter import DateFilter
from cpl.graphql.schema.filter.filter import Filter
from cpl.graphql.schema.filter.int_filter import IntFilter
from cpl.graphql.schema.filter.string_filter import StringFilter
from cpl.graphql.schema.root_mutation import RootMutation
from cpl.graphql.schema.root_query import RootQuery
from cpl.graphql.schema.root_subscription import RootSubscription
from cpl.graphql.service.graphql import GraphQLService
from cpl.graphql.service.schema import Schema
class GraphQLModule(Module):
dependencies = [ApiModule]
singleton = [Schema, RootQuery, RootMutation, RootSubscription]
scoped = [GraphQLService]
transient = [Filter, StringFilter, IntFilter, BoolFilter, DateFilter]
@staticmethod
def configure(services: ServiceProvider) -> None:
schema = services.get_service(Schema)
schema.build()

Some files were not shown because too many files have changed in this diff Show More