Compare commits
1 Commits
2025.10.08
...
2025.09.22
| Author | SHA1 | Date | |
|---|---|---|---|
| 69bbbc8cee |
@@ -25,11 +25,7 @@ jobs:
|
||||
git tag
|
||||
DATE=$(date +'%Y.%m.%d')
|
||||
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
|
||||
if [ "$TAG_COUNT" -eq 0 ]; then
|
||||
BUILD_NUMBER=0
|
||||
else
|
||||
BUILD_NUMBER=$(($TAG_COUNT + 1))
|
||||
fi
|
||||
BUILD_NUMBER=$(($TAG_COUNT + 1))
|
||||
|
||||
VERSION_SUFFIX=${{ inputs.version_suffix }}
|
||||
if [ -n "$VERSION_SUFFIX" ] && [ "$VERSION_SUFFIX" = "dev" ]; then
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -139,6 +139,3 @@ PythonImportHelper-v2-Completion.json
|
||||
|
||||
# cpl unittest stuff
|
||||
unittests/test_*_playground
|
||||
|
||||
# cpl logs
|
||||
**/logs/*.jsonl
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
from cpl.graphql.event_bus.memory import InMemoryEventBus
|
||||
from queries.cities import CityGraphType, CityFilter, CitySort
|
||||
from queries.hello import UserGraphType # , UserFilter, UserSort, UserGraphType
|
||||
from queries.user import UserFilter, UserSort
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl.application.application_builder import ApplicationBuilder
|
||||
from cpl.auth.schema import User, Role
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.environment import Environment
|
||||
from cpl.core.utils.cache import Cache
|
||||
from cpl.database.mysql.mysql_module import MySQLModule
|
||||
from cpl.graphql.application.graphql_app import GraphQLApp
|
||||
from cpl.graphql.auth.graphql_auth_module import GraphQLAuthModule
|
||||
from cpl.graphql.graphql_module import GraphQLModule
|
||||
from model.author_dao import AuthorDao
|
||||
from model.author_query import AuthorGraphType, AuthorFilter, AuthorSort
|
||||
from model.post_dao import PostDao
|
||||
from model.post_query import PostFilter, PostSort, PostGraphType, PostMutation, PostSubscription
|
||||
from permissions import PostPermissions
|
||||
from queries.hello import HelloQuery
|
||||
from scoped_service import ScopedService
|
||||
from service import PingService
|
||||
from test_data_seeder import TestDataSeeder
|
||||
|
||||
|
||||
def main():
|
||||
builder = ApplicationBuilder[GraphQLApp](GraphQLApp)
|
||||
|
||||
Configuration.add_json_file(f"appsettings.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
|
||||
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
|
||||
|
||||
# builder.services.add_logging()
|
||||
(
|
||||
builder.services.add_structured_logging()
|
||||
.add_transient(PingService)
|
||||
.add_module(MySQLModule)
|
||||
.add_module(ApiModule)
|
||||
.add_module(GraphQLModule)
|
||||
.add_module(GraphQLAuthModule)
|
||||
.add_scoped(ScopedService)
|
||||
.add_singleton(EventBusABC, InMemoryEventBus)
|
||||
.add_cache(User)
|
||||
.add_cache(Role)
|
||||
.add_transient(CityGraphType)
|
||||
.add_transient(CityFilter)
|
||||
.add_transient(CitySort)
|
||||
.add_transient(UserGraphType)
|
||||
.add_transient(UserFilter)
|
||||
.add_transient(UserSort)
|
||||
# .add_transient(UserGraphType)
|
||||
# .add_transient(UserFilter)
|
||||
# .add_transient(UserSort)
|
||||
.add_transient(HelloQuery)
|
||||
# test data
|
||||
.add_singleton(TestDataSeeder)
|
||||
# authors
|
||||
.add_transient(AuthorDao)
|
||||
.add_transient(AuthorGraphType)
|
||||
.add_transient(AuthorFilter)
|
||||
.add_transient(AuthorSort)
|
||||
# posts
|
||||
.add_transient(PostDao)
|
||||
.add_transient(PostGraphType)
|
||||
.add_transient(PostFilter)
|
||||
.add_transient(PostSort)
|
||||
.add_transient(PostMutation)
|
||||
.add_transient(PostSubscription)
|
||||
)
|
||||
|
||||
app = builder.build()
|
||||
app.with_logging()
|
||||
app.with_migrations("./scripts")
|
||||
|
||||
app.with_authentication()
|
||||
app.with_authorization()
|
||||
|
||||
app.with_route(
|
||||
path="/route1",
|
||||
fn=lambda r: JSONResponse("route1"),
|
||||
method="GET",
|
||||
# authentication=True,
|
||||
# permissions=[Permissions.administrator],
|
||||
)
|
||||
app.with_routes_directory("routes")
|
||||
|
||||
schema = app.with_graphql()
|
||||
schema.query.string_field("ping", resolver=lambda: "pong")
|
||||
schema.query.with_query("hello", HelloQuery)
|
||||
schema.query.dao_collection_field(AuthorGraphType, AuthorDao, "authors", AuthorFilter, AuthorSort)
|
||||
(
|
||||
schema.query.dao_collection_field(PostGraphType, PostDao, "posts", PostFilter, PostSort)
|
||||
# .with_require_any_permission(PostPermissions.read)
|
||||
.with_public()
|
||||
)
|
||||
|
||||
schema.mutation.with_mutation("post", PostMutation).with_public()
|
||||
|
||||
schema.subscription.with_subscription(PostSubscription)
|
||||
|
||||
app.with_auth_root_queries(True)
|
||||
app.with_auth_root_mutations(True)
|
||||
|
||||
app.with_playground()
|
||||
app.with_graphiql()
|
||||
|
||||
app.with_permissions(PostPermissions)
|
||||
|
||||
provider = builder.service_provider
|
||||
user_cache = provider.get_service(Cache[User])
|
||||
role_cache = provider.get_service(Cache[Role])
|
||||
|
||||
if role_cache == user_cache:
|
||||
raise Exception("Cache service is not working")
|
||||
|
||||
s1 = provider.get_service(ScopedService)
|
||||
s2 = provider.get_service(ScopedService)
|
||||
|
||||
if s1.name == s2.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
with provider.create_scope() as scope:
|
||||
s3 = scope.get_service(ScopedService)
|
||||
s4 = scope.get_service(ScopedService)
|
||||
|
||||
if s3.name != s4.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
if s1.name == s3.name:
|
||||
raise Exception("Scoped service is not working")
|
||||
|
||||
Console.write_line(
|
||||
s1.name,
|
||||
s2.name,
|
||||
s3.name,
|
||||
s4.name,
|
||||
)
|
||||
|
||||
app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,30 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Self
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
|
||||
|
||||
class Author(DbModelABC[Self]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._first_name = first_name
|
||||
self._last_name = last_name
|
||||
|
||||
@property
|
||||
def first_name(self) -> str:
|
||||
return self._first_name
|
||||
|
||||
@property
|
||||
def last_name(self) -> str:
|
||||
return self._last_name
|
||||
@@ -1,11 +0,0 @@
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from model.author import Author
|
||||
|
||||
|
||||
class AuthorDao(DbModelDaoABC):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, Author, "authors")
|
||||
|
||||
self.attribute(Author.first_name, str, db_name="firstname")
|
||||
self.attribute(Author.last_name, str, db_name="lastname")
|
||||
@@ -1,37 +0,0 @@
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from model.author import Author
|
||||
|
||||
class AuthorFilter(DbModelFilter[Author]):
|
||||
def __init__(self):
|
||||
DbModelFilter.__init__(self, public=True)
|
||||
self.int_field("id")
|
||||
self.string_field("firstName")
|
||||
self.string_field("lastName")
|
||||
|
||||
class AuthorSort(Sort[Author]):
|
||||
def __init__(self):
|
||||
Sort.__init__(self)
|
||||
self.field("id", SortOrder)
|
||||
self.field("firstName", SortOrder)
|
||||
self.field("lastName", SortOrder)
|
||||
|
||||
class AuthorGraphType(DbModelGraphType[Author]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelGraphType.__init__(self, public=True)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"firstName",
|
||||
resolver=lambda root: root.first_name,
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"lastName",
|
||||
resolver=lambda root: root.last_name,
|
||||
).with_public(True)
|
||||
@@ -1,44 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Self
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
|
||||
|
||||
class Post(DbModelABC[Self]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
author_id: SerialId,
|
||||
title: str,
|
||||
content: str,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._author_id = author_id
|
||||
self._title = title
|
||||
self._content = content
|
||||
|
||||
@property
|
||||
def author_id(self) -> SerialId:
|
||||
return self._author_id
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return self._title
|
||||
|
||||
@title.setter
|
||||
def title(self, value: str):
|
||||
self._title = value
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return self._content
|
||||
|
||||
@content.setter
|
||||
def content(self, value: str):
|
||||
self._content = value
|
||||
@@ -1,15 +0,0 @@
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from model.author_dao import AuthorDao
|
||||
from model.post import Post
|
||||
|
||||
|
||||
class PostDao(DbModelDaoABC[Post]):
|
||||
|
||||
def __init__(self, authors: AuthorDao):
|
||||
DbModelDaoABC.__init__(self, Post, "posts")
|
||||
|
||||
self.attribute(Post.author_id, int, db_name="authorId")
|
||||
self.reference("author", "id", Post.author_id, "authors", authors)
|
||||
|
||||
self.attribute(Post.title, str)
|
||||
self.attribute(Post.content, str)
|
||||
@@ -1,148 +0,0 @@
|
||||
from cpl.dependency.event_bus import EventBusABC
|
||||
from cpl.graphql.query_context import QueryContext
|
||||
from cpl.graphql.schema.db_model_graph_type import DbModelGraphType
|
||||
from cpl.graphql.schema.filter.db_model_filter import DbModelFilter
|
||||
from cpl.graphql.schema.input import Input
|
||||
from cpl.graphql.schema.mutation import Mutation
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
from cpl.graphql.schema.subscription import Subscription
|
||||
from model.author_dao import AuthorDao
|
||||
from model.author_query import AuthorGraphType, AuthorFilter
|
||||
from model.post import Post
|
||||
from model.post_dao import PostDao
|
||||
|
||||
|
||||
class PostFilter(DbModelFilter[Post]):
|
||||
def __init__(self):
|
||||
DbModelFilter.__init__(self, public=True)
|
||||
self.int_field("id")
|
||||
self.filter_field("author", AuthorFilter)
|
||||
self.string_field("title")
|
||||
self.string_field("content")
|
||||
|
||||
|
||||
class PostSort(Sort[Post]):
|
||||
def __init__(self):
|
||||
Sort.__init__(self)
|
||||
self.field("id", SortOrder)
|
||||
self.field("title", SortOrder)
|
||||
self.field("content", SortOrder)
|
||||
|
||||
|
||||
class PostGraphType(DbModelGraphType[Post]):
|
||||
|
||||
def __init__(self, authors: AuthorDao):
|
||||
DbModelGraphType.__init__(self, public=True)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
).with_optional().with_public(True)
|
||||
|
||||
async def _a(root: Post):
|
||||
return await authors.get_by_id(root.author_id)
|
||||
|
||||
def r_name(ctx: QueryContext):
|
||||
return ctx.user.username == "admin"
|
||||
|
||||
self.object_field("author", AuthorGraphType, resolver=_a).with_public(True) # .with_require_any([], [r_name]))
|
||||
self.string_field(
|
||||
"title",
|
||||
resolver=lambda root: root.title,
|
||||
).with_public(True)
|
||||
self.string_field(
|
||||
"content",
|
||||
resolver=lambda root: root.content,
|
||||
).with_public(True)
|
||||
|
||||
|
||||
class PostCreateInput(Input[Post]):
|
||||
title: str
|
||||
content: str
|
||||
author_id: int
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.string_field("title").with_required()
|
||||
self.string_field("content").with_required()
|
||||
self.int_field("author_id").with_required()
|
||||
|
||||
|
||||
class PostUpdateInput(Input[Post]):
|
||||
title: str
|
||||
content: str
|
||||
author_id: int
|
||||
|
||||
def __init__(self):
|
||||
Input.__init__(self)
|
||||
self.int_field("id").with_required()
|
||||
self.string_field("title").with_required(False)
|
||||
self.string_field("content").with_required(False)
|
||||
|
||||
|
||||
class PostSubscription(Subscription):
|
||||
def __init__(self, bus: EventBusABC):
|
||||
Subscription.__init__(self)
|
||||
self._bus = bus
|
||||
|
||||
def selector(event: Post, info) -> bool:
|
||||
return event.id == 101
|
||||
|
||||
self.subscription_field("postChange", PostGraphType, selector).with_public()
|
||||
|
||||
|
||||
class PostMutation(Mutation):
|
||||
|
||||
def __init__(self, posts: PostDao, authors: AuthorDao, bus: EventBusABC):
|
||||
Mutation.__init__(self)
|
||||
|
||||
self._posts = posts
|
||||
self._authors = authors
|
||||
self._bus = bus
|
||||
|
||||
self.field("create", int, resolver=self.create_post).with_public().with_required().with_argument(
|
||||
"input",
|
||||
PostCreateInput,
|
||||
).with_required()
|
||||
self.field("update", bool, resolver=self.update_post).with_public().with_required().with_argument(
|
||||
"input",
|
||||
PostUpdateInput,
|
||||
).with_required()
|
||||
self.field("delete", bool, resolver=self.delete_post).with_public().with_required().with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
self.field("restore", bool, resolver=self.restore_post).with_public().with_required().with_argument(
|
||||
"id",
|
||||
int,
|
||||
).with_required()
|
||||
|
||||
async def create_post(self, input: PostCreateInput) -> int:
|
||||
return await self._posts.create(Post(0, input.author_id, input.title, input.content))
|
||||
|
||||
async def update_post(self, input: PostUpdateInput) -> bool:
|
||||
post = await self._posts.get_by_id(input.id)
|
||||
if post is None:
|
||||
return False
|
||||
|
||||
post.title = input.title if input.title is not None else post.title
|
||||
post.content = input.content if input.content is not None else post.content
|
||||
|
||||
await self._posts.update(post)
|
||||
await self._bus.publish("postChange", post)
|
||||
return True
|
||||
|
||||
async def delete_post(self, id: int) -> bool:
|
||||
post = await self._posts.get_by_id(id)
|
||||
if post is None:
|
||||
return False
|
||||
await self._posts.delete(post)
|
||||
return True
|
||||
|
||||
async def restore_post(self, id: int) -> bool:
|
||||
post = await self._posts.get_by_id(id)
|
||||
if post is None:
|
||||
return False
|
||||
await self._posts.restore(post)
|
||||
return True
|
||||
@@ -1,8 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PostPermissions(Enum):
|
||||
|
||||
read = "post.read"
|
||||
write = "post.write"
|
||||
delete = "post.delete"
|
||||
@@ -1,39 +0,0 @@
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class City:
|
||||
def __init__(self, id: int, name: str):
|
||||
self.id = id
|
||||
self.name = name
|
||||
|
||||
|
||||
class CityFilter(Filter[City]):
|
||||
def __init__(self):
|
||||
Filter.__init__(self)
|
||||
self.field("id", int)
|
||||
self.field("name", str)
|
||||
|
||||
|
||||
class CitySort(Sort[City]):
|
||||
def __init__(self):
|
||||
Sort.__init__(self)
|
||||
self.field("id", SortOrder)
|
||||
self.field("name", SortOrder)
|
||||
|
||||
|
||||
class CityGraphType(GraphType[City]):
|
||||
def __init__(self):
|
||||
GraphType.__init__(self)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
)
|
||||
self.string_field(
|
||||
"name",
|
||||
resolver=lambda root: root.name,
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
from queries.cities import CityFilter, CitySort, CityGraphType, City
|
||||
from queries.user import User, UserFilter, UserSort, UserGraphType
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.auth.schema import UserDao, User
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.query import Query
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
users = [User(i, f"User {i}") for i in range(1, 101)]
|
||||
cities = [City(i, f"City {i}") for i in range(1, 101)]
|
||||
|
||||
# class UserFilter(Filter[User]):
|
||||
# def __init__(self):
|
||||
# Filter.__init__(self)
|
||||
# self.field("id", int)
|
||||
# self.field("username", str)
|
||||
#
|
||||
#
|
||||
# class UserSort(Sort[User]):
|
||||
# def __init__(self):
|
||||
# Sort.__init__(self)
|
||||
# self.field("id", SortOrder)
|
||||
# self.field("username", SortOrder)
|
||||
#
|
||||
# class UserGraphType(GraphType[User]):
|
||||
#
|
||||
# def __init__(self):
|
||||
# GraphType.__init__(self)
|
||||
#
|
||||
# self.int_field(
|
||||
# "id",
|
||||
# resolver=lambda root: root.id,
|
||||
# )
|
||||
# self.string_field(
|
||||
# "username",
|
||||
# resolver=lambda root: root.username,
|
||||
# )
|
||||
|
||||
|
||||
class HelloQuery(Query):
|
||||
def __init__(self):
|
||||
Query.__init__(self)
|
||||
self.string_field(
|
||||
"message",
|
||||
resolver=lambda name: f"Hello {name} {get_request().state.request_id}",
|
||||
).with_argument("name", str, "Name to greet", "world")
|
||||
|
||||
self.collection_field(
|
||||
UserGraphType,
|
||||
"users",
|
||||
UserFilter,
|
||||
UserSort,
|
||||
resolver=lambda: users,
|
||||
)
|
||||
self.collection_field(
|
||||
CityGraphType,
|
||||
"cities",
|
||||
CityFilter,
|
||||
CitySort,
|
||||
resolver=lambda: cities,
|
||||
)
|
||||
# self.dao_collection_field(
|
||||
# UserGraphType,
|
||||
# UserDao,
|
||||
# "Users",
|
||||
# UserFilter,
|
||||
# UserSort,
|
||||
# )
|
||||
@@ -1,39 +0,0 @@
|
||||
from cpl.graphql.schema.filter.filter import Filter
|
||||
from cpl.graphql.schema.graph_type import GraphType
|
||||
from cpl.graphql.schema.sort.sort import Sort
|
||||
from cpl.graphql.schema.sort.sort_order import SortOrder
|
||||
|
||||
|
||||
class User:
|
||||
def __init__(self, id: int, name: str):
|
||||
self.id = id
|
||||
self.name = name
|
||||
|
||||
|
||||
class UserFilter(Filter[User]):
|
||||
def __init__(self):
|
||||
Filter.__init__(self)
|
||||
self.field("id", int)
|
||||
self.field("name", str)
|
||||
|
||||
|
||||
class UserSort(Sort[User]):
|
||||
def __init__(self):
|
||||
Sort.__init__(self)
|
||||
self.field("id", SortOrder)
|
||||
self.field("name", SortOrder)
|
||||
|
||||
|
||||
class UserGraphType(GraphType[User]):
|
||||
|
||||
def __init__(self):
|
||||
GraphType.__init__(self)
|
||||
|
||||
self.int_field(
|
||||
"id",
|
||||
resolver=lambda root: root.id,
|
||||
)
|
||||
self.string_field(
|
||||
"name",
|
||||
resolver=lambda root: root.name,
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
from urllib.request import Request
|
||||
|
||||
from service import PingService
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from cpl.api import APILogger
|
||||
from cpl.api.router import Router
|
||||
from cpl.core.console import Console
|
||||
from cpl.dependency import ServiceProvider
|
||||
from scoped_service import ScopedService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
# @Router.authorize(permissions=[Permissions.administrator])
|
||||
# @Router.authorize(policies=["test"])
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: APILogger, provider: ServiceProvider, scoped: ScopedService):
|
||||
logger.info(f"Ping: {ping}")
|
||||
|
||||
Console.write_line(scoped.name)
|
||||
return JSONResponse(ping.ping(r))
|
||||
@@ -1,14 +0,0 @@
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.utils.string import String
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self._name = String.random(8)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def run(self):
|
||||
Console.write_line(f"Im {self._name}")
|
||||
@@ -1,22 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS `authors` (
|
||||
`id` INT(30) NOT NULL AUTO_INCREMENT,
|
||||
`firstname` VARCHAR(64) NOT NULL,
|
||||
`lastname` VARCHAR(64) NOT NULL,
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY(`id`)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `posts` (
|
||||
`id` INT(30) NOT NULL AUTO_INCREMENT,
|
||||
`authorId` INT(30) NOT NULL REFERENCES `authors`(`id`) ON DELETE CASCADE,
|
||||
`title` TEXT NOT NULL,
|
||||
`content` TEXT NOT NULL,
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY(`id`)
|
||||
);
|
||||
@@ -1,48 +0,0 @@
|
||||
from faker import Faker
|
||||
|
||||
from cpl.database.abc import DataSeederABC
|
||||
from cpl.query import Enumerable
|
||||
from model.author import Author
|
||||
from model.author_dao import AuthorDao
|
||||
from model.post import Post
|
||||
from model.post_dao import PostDao
|
||||
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
class TestDataSeeder(DataSeederABC):
|
||||
|
||||
def __init__(self, authors: AuthorDao, posts: PostDao):
|
||||
DataSeederABC.__init__(self)
|
||||
|
||||
self._authors = authors
|
||||
self._posts = posts
|
||||
|
||||
async def seed(self):
|
||||
if await self._authors.count() == 0:
|
||||
await self._seed_authors()
|
||||
|
||||
if await self._posts.count() == 0:
|
||||
await self._seed_posts()
|
||||
|
||||
async def _seed_authors(self):
|
||||
authors = Enumerable.range(0, 35).select(
|
||||
lambda x: Author(
|
||||
0,
|
||||
fake.first_name(),
|
||||
fake.last_name(),
|
||||
)
|
||||
).to_list()
|
||||
await self._authors.create_many(authors, skip_editor=True)
|
||||
|
||||
async def _seed_posts(self):
|
||||
posts = Enumerable.range(0, 100).select(
|
||||
lambda x: Post(
|
||||
id=0,
|
||||
author_id=fake.random_int(min=1, max=35),
|
||||
title=fake.sentence(nb_words=6),
|
||||
content=fake.paragraph(nb_sentences=6),
|
||||
)
|
||||
).to_list()
|
||||
await self._posts.create_many(posts, skip_editor=True)
|
||||
@@ -1,45 +0,0 @@
|
||||
from cpl.application.abc import ApplicationABC
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.dependency import ServiceProvider
|
||||
from test_abc import TestABC
|
||||
from test_service import TestService
|
||||
from di_tester_service import DITesterService
|
||||
from tester import Tester
|
||||
|
||||
|
||||
class Application(ApplicationABC):
|
||||
def __init__(self, services: ServiceProvider):
|
||||
ApplicationABC.__init__(self, services)
|
||||
|
||||
def _part_of_scoped(self):
|
||||
ts: TestService = self._services.get_service(TestService)
|
||||
ts.run()
|
||||
|
||||
def main(self):
|
||||
with self._services.create_scope() as scope:
|
||||
Console.write_line("Scope1")
|
||||
ts: TestService = scope.get_service(TestService)
|
||||
ts.run()
|
||||
dit: DITesterService = scope.get_service(DITesterService)
|
||||
dit.run()
|
||||
|
||||
if ts.name != dit.name:
|
||||
raise Exception("DI is broken!")
|
||||
|
||||
with self._services.create_scope() as scope:
|
||||
Console.write_line("Scope2")
|
||||
ts: TestService = scope.get_service(TestService)
|
||||
ts.run()
|
||||
dit: DITesterService = scope.get_service(DITesterService)
|
||||
dit.run()
|
||||
|
||||
if ts.name != dit.name:
|
||||
raise Exception("DI is broken!")
|
||||
|
||||
Console.write_line("Global")
|
||||
self._part_of_scoped()
|
||||
#from static_test import StaticTest
|
||||
#StaticTest.test()
|
||||
|
||||
self._services.get_service(Tester)
|
||||
Console.write_line(self._services.get_services(TestABC))
|
||||
@@ -1,27 +0,0 @@
|
||||
from cpl.application.abc import StartupABC
|
||||
from cpl.dependency import ServiceProvider, ServiceCollection
|
||||
from di_tester_service import DITesterService
|
||||
from test1_service import Test1Service
|
||||
from test2_service import Test2Service
|
||||
from test_abc import TestABC
|
||||
from test_service import TestService
|
||||
from tester import Tester
|
||||
|
||||
|
||||
class Startup(StartupABC):
|
||||
def __init__(self):
|
||||
StartupABC.__init__(self)
|
||||
|
||||
@staticmethod
|
||||
def configure_configuration(): ...
|
||||
|
||||
@staticmethod
|
||||
def configure_services(services: ServiceCollection) -> ServiceProvider:
|
||||
services.add_scoped(TestService)
|
||||
services.add_scoped(DITesterService)
|
||||
|
||||
services.add_singleton(TestABC, Test1Service)
|
||||
services.add_singleton(TestABC, Test2Service)
|
||||
services.add_singleton(Tester)
|
||||
|
||||
return services.build()
|
||||
@@ -1,10 +0,0 @@
|
||||
from cpl.dependency import ServiceProvider, ServiceProvider
|
||||
from cpl.dependency.inject import inject
|
||||
from test_service import TestService
|
||||
|
||||
|
||||
class StaticTest:
|
||||
@staticmethod
|
||||
@inject
|
||||
def test(services: ServiceProvider, t1: TestService):
|
||||
t1.run()
|
||||
@@ -1,7 +0,0 @@
|
||||
from cpl.core.console.console import Console
|
||||
from test_abc import TestABC
|
||||
|
||||
|
||||
class Tester:
|
||||
def __init__(self, t1: TestABC, t2: TestABC, t3: TestABC, t: list[TestABC]):
|
||||
Console.write_line("Tester:", t, t1, t2, t3)
|
||||
@@ -1,30 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.time.cron import Cron
|
||||
from cpl.dependency.hosted.cronjob import CronjobABC
|
||||
from cpl.dependency.hosted.hosted_service import HostedService
|
||||
|
||||
|
||||
class Hosted(HostedService):
|
||||
def __init__(self):
|
||||
self._stopped = False
|
||||
|
||||
async def start(self):
|
||||
Console.write_line("Hosted Service Started")
|
||||
while not self._stopped:
|
||||
Console.write_line("Hosted Service Running")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self):
|
||||
Console.write_line("Hosted Service Stopped")
|
||||
self._stopped = True
|
||||
|
||||
|
||||
class MyCronJob(CronjobABC):
|
||||
def __init__(self):
|
||||
CronjobABC.__init__(self, Cron("*/1 * * * *")) # Every minute
|
||||
|
||||
async def loop(self):
|
||||
Console.write_line(f"[{datetime.now()}] Hello from Cronjob!")
|
||||
@@ -1,10 +0,0 @@
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class ScopedService:
|
||||
def __init__(self):
|
||||
self.value = "I am a scoped service"
|
||||
Console.write_line(self.value, self)
|
||||
|
||||
def get_value(self):
|
||||
return self.value
|
||||
@@ -1,60 +0,0 @@
|
||||
from cpl.core.console import Console
|
||||
from cpl.core.utils.benchmark import Benchmark
|
||||
from cpl.query.enumerable import Enumerable
|
||||
from cpl.query.immutable_list import ImmutableList
|
||||
from cpl.query.list import List
|
||||
from cpl.query.set import Set
|
||||
|
||||
|
||||
def _default():
|
||||
Console.write_line(Enumerable.empty().to_list())
|
||||
|
||||
Console.write_line(Enumerable.range(0, 100).length)
|
||||
Console.write_line(Enumerable.range(0, 100).to_list())
|
||||
|
||||
Console.write_line(Enumerable.range(0, 100).where(lambda x: x % 2 == 0).length)
|
||||
Console.write_line(
|
||||
Enumerable.range(0, 100).where(lambda x: x % 2 == 0).to_list().select(lambda x: str(x)).to_list()
|
||||
)
|
||||
Console.write_line(List)
|
||||
|
||||
s =Enumerable.range(0, 10).to_set()
|
||||
Console.write_line(s)
|
||||
s.add(1)
|
||||
Console.write_line(s)
|
||||
|
||||
data = Enumerable(
|
||||
[
|
||||
{"name": "Alice", "age": 30},
|
||||
{"name": "Dave", "age": 35},
|
||||
{"name": "Charlie", "age": 25},
|
||||
{"name": "Bob", "age": 25},
|
||||
]
|
||||
)
|
||||
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["age"]).then_by(lambda x: x["name"]).to_list())
|
||||
Console.write_line(data.order_by(lambda x: x["name"]).then_by(lambda x: x["age"]).to_list())
|
||||
|
||||
|
||||
def t_benchmark(data: list):
|
||||
Benchmark.all("Enumerable", lambda: Enumerable(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("Set", lambda: Set(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all("List", lambda: List(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list())
|
||||
Benchmark.all(
|
||||
"ImmutableList", lambda: ImmutableList(data).where(lambda x: x % 2 == 0).select(lambda x: x * 2).to_list()
|
||||
)
|
||||
Benchmark.all("List comprehension", lambda: [x * 2 for x in data if x % 2 == 0])
|
||||
|
||||
|
||||
def main():
|
||||
N = 1_000_000
|
||||
data = list(range(N))
|
||||
t_benchmark(data)
|
||||
|
||||
Console.write_line()
|
||||
_default()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,32 @@
|
||||
from .error import APIError, AlreadyExists, EndpointNotImplemented, Forbidden, NotFound, Unauthorized
|
||||
from .logger import APILogger
|
||||
from .settings import ApiSettings
|
||||
from .api_module import ApiModule
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
|
||||
|
||||
def add_api(collection: _ServiceCollection):
|
||||
try:
|
||||
from cpl.database import mysql
|
||||
|
||||
collection.add_module(mysql)
|
||||
except ImportError as e:
|
||||
from cpl.core.errors import dependency_error
|
||||
|
||||
dependency_error("cpl-database", e)
|
||||
|
||||
try:
|
||||
from cpl import auth
|
||||
from cpl.auth import permission
|
||||
|
||||
collection.add_module(auth)
|
||||
collection.add_module(permission)
|
||||
except ImportError as e:
|
||||
from cpl.core.errors import dependency_error
|
||||
|
||||
dependency_error("cpl-auth", e)
|
||||
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
|
||||
collection.add_singleton(PolicyRegistry)
|
||||
collection.add_singleton(RouteRegistry)
|
||||
|
||||
|
||||
_ServiceCollection.with_module(add_api, __name__)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .asgi_middleware_abc import ASGIMiddleware
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput
|
||||
from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
|
||||
|
||||
class WebAppABC(ApplicationABC, ABC):
|
||||
|
||||
def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
|
||||
ApplicationABC.__init__(self, services, modules, required_modules)
|
||||
|
||||
def with_routes_directory(self, directory: str) -> Self: ...
|
||||
def with_app(self, app: Starlette) -> Self: ...
|
||||
def with_routes(
|
||||
self,
|
||||
routes: list[ApiRoute],
|
||||
method: HTTPMethods,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self: ...
|
||||
def with_route(
|
||||
self,
|
||||
path: str,
|
||||
fn: TEndpoint,
|
||||
method: HTTPMethods,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self: ...
|
||||
def with_middleware(self, middleware: PartialMiddleware) -> Self: ...
|
||||
def with_authentication(self) -> Self: ...
|
||||
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self: ...
|
||||
@@ -1,22 +0,0 @@
|
||||
from cpl.api import ApiSettings
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.auth.auth_module import AuthModule
|
||||
from cpl.auth.permission.permission_module import PermissionsModule
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.dependency import ServiceCollection
|
||||
from cpl.dependency.module.module import Module
|
||||
|
||||
|
||||
class ApiModule(Module):
|
||||
config = [ApiSettings]
|
||||
singleton = [
|
||||
PolicyRegistry,
|
||||
RouteRegistry,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
collection.add_module(DatabaseModule)
|
||||
collection.add_module(AuthModule)
|
||||
collection.add_module(PermissionsModule)
|
||||
@@ -1 +0,0 @@
|
||||
from .web_app import WebApp
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Mapping, Any, Self
|
||||
from typing import Mapping, Any, Callable, Self, Union
|
||||
|
||||
import uvicorn
|
||||
from starlette.applications import Starlette
|
||||
@@ -10,8 +10,7 @@ from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ExceptionHandler
|
||||
|
||||
from cpl.api.abc.web_app_abc import WebAppABC
|
||||
from cpl.api.api_module import ApiModule
|
||||
from cpl import api, auth
|
||||
from cpl.api.error import APIError
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||
@@ -25,46 +24,43 @@ from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.api.settings import ApiSettings
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, TEndpoint, PolicyInput
|
||||
from cpl.auth.auth_module import AuthModule
|
||||
from cpl.auth.permission.permission_module import PermissionsModule
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import Modules
|
||||
from cpl.api.typing import HTTPMethods, PartialMiddleware, PolicyResolver
|
||||
from cpl.application.abc.application_abc import ApplicationABC
|
||||
from cpl.core.configuration import Configuration
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = APILogger("API")
|
||||
|
||||
class WebApp(WebAppABC):
|
||||
def __init__(self, services: ServiceProvider, modules: Modules, required_modules: list[str | object] = None):
|
||||
WebAppABC.__init__(
|
||||
self, services, modules, [AuthModule, PermissionsModule, ApiModule] + (required_modules or [])
|
||||
)
|
||||
PolicyInput = Union[dict[str, PolicyResolver], Policy]
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
super().__init__(services, [auth, api])
|
||||
self._app: Starlette | None = None
|
||||
|
||||
self._logger = services.get_service(APILogger)
|
||||
|
||||
self._api_settings = Configuration.get(ApiSettings)
|
||||
self._policies = services.get_service(PolicyRegistry)
|
||||
self._routes = services.get_service(RouteRegistry)
|
||||
|
||||
self._middleware: list[Middleware] = []
|
||||
self._middleware: list[Middleware] = [
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
]
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||
Exception: self._handle_exception,
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
self.with_middleware(RequestMiddleware)
|
||||
self.with_middleware(LoggingMiddleware)
|
||||
|
||||
async def _handle_exception(self, request: Request, exc: Exception):
|
||||
@staticmethod
|
||||
async def _handle_exception(request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
self._logger.error(exc)
|
||||
_logger.error(exc)
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
if hasattr(request.state, "request_id"):
|
||||
self._logger.error(f"Request {request.state.request_id}", exc)
|
||||
_logger.error(f"Request {request.state.request_id}", exc)
|
||||
else:
|
||||
self._logger.error("Request unknown", exc)
|
||||
_logger.error("Request unknown", exc)
|
||||
|
||||
return JSONResponse({"error": str(exc)}, status_code=500)
|
||||
|
||||
@@ -72,23 +68,27 @@ class WebApp(WebAppABC):
|
||||
origins = self._api_settings.allowed_origins
|
||||
|
||||
if origins is None or origins == "":
|
||||
self._logger.warning("No allowed origins specified, allowing all origins")
|
||||
_logger.warning("No allowed origins specified, allowing all origins")
|
||||
return ["*"]
|
||||
|
||||
self._logger.debug(f"Allowed origins: {origins}")
|
||||
_logger.debug(f"Allowed origins: {origins}")
|
||||
return origins.split(",")
|
||||
|
||||
def with_database(self) -> Self:
|
||||
self.with_migrations()
|
||||
self.with_seeders()
|
||||
return self
|
||||
|
||||
def with_app(self, app: Starlette) -> Self:
|
||||
assert app is not None, "app must not be None"
|
||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||
self._app = app
|
||||
return self
|
||||
|
||||
def _check_for_app(self):
|
||||
if self._app is not None:
|
||||
raise ValueError("App is already set, cannot add routes or middleware")
|
||||
|
||||
def _validate_policies(self):
|
||||
for rule in Router.get_authorization_rules():
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
self._logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||
|
||||
def with_routes_directory(self, directory: str) -> Self:
|
||||
self._check_for_app()
|
||||
assert directory is not None, "directory must not be None"
|
||||
@@ -103,12 +103,6 @@ class WebApp(WebAppABC):
|
||||
|
||||
return self
|
||||
|
||||
def with_app(self, app: Starlette) -> Self:
|
||||
assert app is not None, "app must not be None"
|
||||
assert isinstance(app, Starlette), "app must be an instance of Starlette"
|
||||
self._app = app
|
||||
return self
|
||||
|
||||
def with_routes(
|
||||
self,
|
||||
routes: list[ApiRoute],
|
||||
@@ -138,7 +132,7 @@ class WebApp(WebAppABC):
|
||||
def with_route(
|
||||
self,
|
||||
path: str,
|
||||
fn: TEndpoint,
|
||||
fn: Callable[[Request], Any],
|
||||
method: HTTPMethods,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
@@ -169,37 +163,13 @@ class WebApp(WebAppABC):
|
||||
|
||||
return self
|
||||
|
||||
def with_websocket(
|
||||
self,
|
||||
path: str,
|
||||
fn: TEndpoint,
|
||||
authentication: bool = False,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
) -> Self:
|
||||
self._check_for_app()
|
||||
assert path is not None, "path must not be None"
|
||||
assert fn is not None, "fn must not be None"
|
||||
|
||||
Router.websocket(path, registry=self._routes)(fn)
|
||||
|
||||
if authentication:
|
||||
Router.authenticate()(fn)
|
||||
|
||||
if roles or permissions or policies:
|
||||
Router.authorize(roles, permissions, policies, match)(fn)
|
||||
|
||||
return self
|
||||
|
||||
def with_middleware(self, middleware: PartialMiddleware) -> Self:
|
||||
self._check_for_app()
|
||||
|
||||
if isinstance(middleware, Middleware):
|
||||
self._middleware.append(inject(middleware))
|
||||
self._middleware.append(middleware)
|
||||
elif callable(middleware):
|
||||
self._middleware.append(Middleware(inject(middleware)))
|
||||
self._middleware.append(Middleware(middleware))
|
||||
else:
|
||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||
|
||||
@@ -210,7 +180,6 @@ class WebApp(WebAppABC):
|
||||
return self
|
||||
|
||||
def with_authorization(self, *policies: list[PolicyInput] | PolicyInput) -> Self:
|
||||
self._check_for_app()
|
||||
if policies:
|
||||
_policies = []
|
||||
|
||||
@@ -221,11 +190,11 @@ class WebApp(WebAppABC):
|
||||
if isinstance(policy, dict):
|
||||
for name, resolver in policy.items():
|
||||
if not isinstance(name, str):
|
||||
self._logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||
_logger.warning(f"Skipping policy at index {i}, name must be a string")
|
||||
continue
|
||||
|
||||
if not callable(resolver):
|
||||
self._logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||
_logger.warning(f"Skipping policy {name}, resolver must be callable")
|
||||
continue
|
||||
|
||||
_policies.append(Policy(name, resolver))
|
||||
@@ -233,20 +202,24 @@ class WebApp(WebAppABC):
|
||||
|
||||
_policies.append(policy)
|
||||
|
||||
self._policies.extend(_policies)
|
||||
self._policies.extend_policies(_policies)
|
||||
|
||||
self.with_middleware(AuthorizationMiddleware)
|
||||
return self
|
||||
|
||||
async def _log_before_startup(self):
|
||||
self._logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
def _validate_policies(self):
|
||||
for rule in Router.get_authorization_rules():
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
_logger.fatal(f"Authorization policy '{policy_name}' not found")
|
||||
|
||||
async def main(self):
|
||||
self._logger.debug(f"Preparing API")
|
||||
_logger.debug(f"Preparing API")
|
||||
self._validate_policies()
|
||||
|
||||
if self._app is None:
|
||||
routes = [route.to_starlette(inject) for route in self._routes.all()]
|
||||
routes = [route.to_starlette(self._services.inject) for route in self._routes.all()]
|
||||
|
||||
app = Starlette(
|
||||
routes=routes,
|
||||
@@ -264,7 +237,7 @@ class WebApp(WebAppABC):
|
||||
else:
|
||||
app = self._app
|
||||
|
||||
await self._log_before_startup()
|
||||
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
|
||||
|
||||
config = uvicorn.Config(
|
||||
app, host=self._api_settings.host, port=self._api_settings.port, log_config=None, loop="asyncio"
|
||||
@@ -272,4 +245,4 @@ class WebApp(WebAppABC):
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
self._logger.info("Shutdown API")
|
||||
_logger.info("Shutdown API")
|
||||
|
||||
@@ -8,7 +8,7 @@ class APIError(HTTPException):
|
||||
status_code = 500
|
||||
|
||||
def __init__(self, message: str = ""):
|
||||
HTTPException.__init__(self, self.status_code, message)
|
||||
super().__init__(self.status_code, message)
|
||||
self._message = message
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
from cpl.core.log.logger import Logger
|
||||
|
||||
|
||||
class APILogger(WrappedLogger):
|
||||
class APILogger(Logger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "api")
|
||||
def __init__(self, source: str):
|
||||
Logger.__init__(self, source, "api")
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .authentication import AuthenticationMiddleware
|
||||
from .authorization import AuthorizationMiddleware
|
||||
from .logging import LoggingMiddleware
|
||||
from .request import RequestMiddleware
|
||||
|
||||
@@ -2,22 +2,24 @@ from keycloak import KeycloakAuthenticationError
|
||||
from starlette.types import Scope, Receive, Send
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.auth.schema import UserDao, User
|
||||
from cpl.auth.schema import AuthUserDao, AuthUser
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._keycloak = keycloak
|
||||
self._user_dao = user_dao
|
||||
|
||||
@@ -26,26 +28,11 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
self._logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
user = getattr(request.state, "user", None)
|
||||
if not user or user.deleted:
|
||||
self._logger.debug(f"Unauthorized access to {url}, user missing or deleted")
|
||||
return await Unauthorized("Unauthorized").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
|
||||
async def _old_call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = get_request()
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
self._logger.trace(f"No authentication required for {url}")
|
||||
_logger.trace(f"No authentication required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
self._logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
|
||||
|
||||
auth_header = request.headers.get("Authorization", None)
|
||||
@@ -54,7 +41,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
token = auth_header.split("Bearer ")[1]
|
||||
if not await self._verify_login(token):
|
||||
self._logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
|
||||
|
||||
# check user exists in db, if not create
|
||||
@@ -64,7 +51,7 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
user = await self._get_or_crate_user(keycloak_id)
|
||||
if user.deleted:
|
||||
self._logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
_logger.debug(f"Unauthorized access to {url}, user is deleted")
|
||||
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
|
||||
|
||||
request.state.user = user
|
||||
@@ -72,12 +59,12 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
|
||||
async def _get_or_crate_user(self, keycloak_id: str) -> User:
|
||||
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
|
||||
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
user = User(0, keycloak_id)
|
||||
user = AuthUser(0, keycloak_id)
|
||||
uid = await self._user_dao.create(user)
|
||||
return await self._user_dao.get_by_id(uid)
|
||||
|
||||
@@ -86,8 +73,8 @@ class AuthenticationMiddleware(ASGIMiddleware):
|
||||
token_info = self._keycloak.introspect(token)
|
||||
return token_info.get("active", False)
|
||||
except KeycloakAuthenticationError as e:
|
||||
self._logger.debug(f"Keycloak authentication error: {e}")
|
||||
_logger.debug(f"Keycloak authentication error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self._logger.error(f"Unexpected error during token verification: {e}")
|
||||
_logger.error(f"Unexpected error during token verification: {e}")
|
||||
return False
|
||||
|
||||
@@ -7,17 +7,19 @@ from cpl.api.middleware.request import get_request
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.policy import PolicyRegistry
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.schema._administration.user_dao import UserDao
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, logger: APILogger, policies: PolicyRegistry, user_dao: UserDao):
|
||||
@ServiceProviderABC.inject
|
||||
def __init__(self, app, policies: PolicyRegistry, user_dao: AuthUserDao):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
self._policies = policies
|
||||
self._user_dao = user_dao
|
||||
|
||||
@@ -26,7 +28,7 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_authorization_rules_paths():
|
||||
self._logger.trace(f"No authorization required for {url}")
|
||||
_logger.trace(f"No authorization required for {url}")
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
user = get_user()
|
||||
@@ -51,21 +53,17 @@ class AuthorizationMiddleware(ASGIMiddleware):
|
||||
|
||||
if rule["permissions"]:
|
||||
if match == ValidationMatch.all and not all(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
if match == ValidationMatch.any and not any(p in perm_names for p in rule["permissions"]):
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(
|
||||
scope, receive, send
|
||||
)
|
||||
return await Forbidden(f"missing permissions: {rule["permissions"]}").asgi_response(scope, receive, send)
|
||||
|
||||
for policy_name in rule["policies"]:
|
||||
policy = self._policies.get(policy_name)
|
||||
if not policy:
|
||||
self._logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||
_logger.warning(f"Authorization policy '{policy_name}' not found")
|
||||
continue
|
||||
|
||||
if not await policy.resolve(user):
|
||||
return await Forbidden(f"policy {policy.name} failed").asgi_response(scope, receive, send)
|
||||
|
||||
return await self._call_next(scope, receive, send)
|
||||
return await self._call_next(scope, receive, send)
|
||||
@@ -7,14 +7,14 @@ from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class LoggingMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, logger: APILogger):
|
||||
def __init__(self, app):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._logger = logger
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
await self._call_next(scope, receive, send)
|
||||
@@ -53,8 +53,9 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
}
|
||||
return {key: value for key, value in headers.items() if key in relevant_keys}
|
||||
|
||||
async def _log_request(self, request: Request):
|
||||
self._logger.debug(
|
||||
@classmethod
|
||||
async def _log_request(cls, request: Request):
|
||||
_logger.debug(
|
||||
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
)
|
||||
|
||||
@@ -63,7 +64,7 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
user = get_user()
|
||||
|
||||
request_info = {
|
||||
"headers": self._filter_relevant_headers(dict(request.headers)),
|
||||
"headers": cls._filter_relevant_headers(dict(request.headers)),
|
||||
"args": dict(request.query_params),
|
||||
"form-data": (
|
||||
await request.form()
|
||||
@@ -77,9 +78,10 @@ class LoggingMiddleware(ASGIMiddleware):
|
||||
),
|
||||
}
|
||||
|
||||
self._logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
|
||||
async def _log_after_request(self, request: Request, status_code: int, duration: float):
|
||||
self._logger.info(
|
||||
@staticmethod
|
||||
async def _log_after_request(request: Request, status_code: int, duration: float):
|
||||
_logger.info(
|
||||
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
|
||||
@@ -5,49 +5,35 @@ from uuid import uuid4
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.types import Scope, Receive, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.api.logger import APILogger
|
||||
from cpl.api.typing import TRequest
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient
|
||||
from cpl.auth.schema import User
|
||||
from cpl.auth.schema._administration.user_dao import UserDao
|
||||
from cpl.core.ctx import set_user
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
_request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", default=None)
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class RequestMiddleware(ASGIMiddleware):
|
||||
|
||||
def __init__(self, app, provider: ServiceProvider, logger: APILogger, keycloak: KeycloakClient, user_dao: UserDao):
|
||||
def __init__(self, app):
|
||||
ASGIMiddleware.__init__(self, app)
|
||||
|
||||
self._provider = provider
|
||||
self._logger = logger
|
||||
|
||||
self._keycloak = keycloak
|
||||
self._user_dao = user_dao
|
||||
|
||||
self._ctx_token = None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
request = Request(scope, receive, send) if scope["type"] != "websocket" else WebSocket(scope, receive, send)
|
||||
request = Request(scope, receive, send)
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
await self._try_set_user(request)
|
||||
with self._provider.create_scope():
|
||||
inject(await self._app(scope, receive, send))
|
||||
await self._app(scope, receive, send)
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
|
||||
async def set_request_data(self, request: TRequest):
|
||||
request.state.request_id = uuid4()
|
||||
request.state.start_time = time.time()
|
||||
self._logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
|
||||
self._ctx_token = _request_context.set(request)
|
||||
|
||||
@@ -59,40 +45,9 @@ class RequestMiddleware(ASGIMiddleware):
|
||||
if self._ctx_token is None:
|
||||
return
|
||||
|
||||
self._logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_request_context.reset(self._ctx_token)
|
||||
|
||||
async def _try_set_user(self, request: Request):
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return
|
||||
|
||||
token = auth_header.split("Bearer ")[1]
|
||||
try:
|
||||
token_info = self._keycloak.introspect(token)
|
||||
if not token_info.get("active", False):
|
||||
return
|
||||
|
||||
keycloak_id = self._keycloak.get_user_id(token)
|
||||
if not keycloak_id:
|
||||
return
|
||||
|
||||
user = await self._user_dao.find_by_keycloak_id(keycloak_id)
|
||||
if not user:
|
||||
user = User(0, keycloak_id)
|
||||
uid = await self._user_dao.create(user)
|
||||
user = await self._user_dao.get_by_id(uid)
|
||||
|
||||
if user.deleted:
|
||||
return
|
||||
|
||||
request.state.user = user
|
||||
set_user(user)
|
||||
self._logger.trace(f"User {user.id} bound to request {request.state.request_id}")
|
||||
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Silent user binding failed: {e}")
|
||||
|
||||
|
||||
def get_request() -> Optional[TRequest]:
|
||||
return _request_context.get()
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .api_route import ApiRoute
|
||||
from .policy import Policy
|
||||
from .validation_match import ValidationMatch
|
||||
|
||||
@@ -7,7 +7,13 @@ from cpl.api.typing import HTTPMethods
|
||||
|
||||
class ApiRoute:
|
||||
|
||||
def __init__(self, path: str, fn: Callable, method: HTTPMethods, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
fn: Callable,
|
||||
method: HTTPMethods,
|
||||
**kwargs
|
||||
):
|
||||
self._path = path
|
||||
self._fn = fn
|
||||
self._method = method
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from asyncio import iscoroutinefunction
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Coroutine, Awaitable
|
||||
|
||||
from cpl.api.typing import PolicyResolver
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Callable
|
||||
|
||||
import starlette.routing
|
||||
|
||||
|
||||
class WebSocketRoute:
|
||||
|
||||
def __init__(self, path: str, fn: Callable, **kwargs):
|
||||
self._path = path
|
||||
self._fn = fn
|
||||
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._fn.__name__
|
||||
|
||||
@property
|
||||
def fn(self) -> Callable:
|
||||
return self._fn
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self._path
|
||||
|
||||
@property
|
||||
def kwargs(self) -> dict:
|
||||
return self._kwargs
|
||||
|
||||
def to_starlette(self, *args) -> starlette.routing.WebSocketRoute:
|
||||
return starlette.routing.WebSocketRoute(self._path, self._fn)
|
||||
@@ -1,2 +0,0 @@
|
||||
from .policy import PolicyRegistry
|
||||
from .route import RouteRegistry
|
||||
|
||||
@@ -1,35 +1,33 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from cpl.api.model.policy import Policy
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.api.model.websocket_route import WebSocketRoute
|
||||
from cpl.core.abc.registry_abc import RegistryABC
|
||||
|
||||
TRoute = Union[ApiRoute, WebSocketRoute]
|
||||
|
||||
|
||||
class RouteRegistry(RegistryABC):
|
||||
|
||||
def __init__(self):
|
||||
RegistryABC.__init__(self)
|
||||
|
||||
def extend(self, items: list[TRoute]):
|
||||
def extend(self, items: list[ApiRoute]):
|
||||
for policy in items:
|
||||
self.add(policy)
|
||||
|
||||
def add(self, item: TRoute):
|
||||
assert isinstance(item, (ApiRoute, WebSocketRoute)), "route must be an instance of ApiRoute"
|
||||
def add(self, item: ApiRoute):
|
||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||
|
||||
if item.path in self._items:
|
||||
raise ValueError(f"ApiRoute {item.path} is already registered")
|
||||
|
||||
self._items[item.path] = item
|
||||
|
||||
def set(self, item: TRoute):
|
||||
def set(self, item: ApiRoute):
|
||||
assert isinstance(item, ApiRoute), "route must be an instance of ApiRoute"
|
||||
self._items[item.path] = item
|
||||
|
||||
def get(self, key: str) -> Optional[TRoute]:
|
||||
def get(self, key: str) -> Optional[ApiRoute]:
|
||||
return self._items.get(key)
|
||||
|
||||
def all(self) -> list[TRoute]:
|
||||
def all(self) -> list[ApiRoute]:
|
||||
return list(self._items.values())
|
||||
|
||||
@@ -3,7 +3,6 @@ from enum import Enum
|
||||
from cpl.api.model.validation_match import ValidationMatch
|
||||
from cpl.api.registry.route import RouteRegistry
|
||||
from cpl.api.typing import HTTPMethods
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class Router:
|
||||
@@ -42,13 +41,7 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def authorize(
|
||||
cls,
|
||||
roles: list[str | Enum] = None,
|
||||
permissions: list[str | Enum] = None,
|
||||
policies: list[str] = None,
|
||||
match: ValidationMatch = None,
|
||||
):
|
||||
def authorize(cls, roles: list[str | Enum]=None, permissions: list[str | Enum]=None, policies: list[str]=None, match: ValidationMatch=None):
|
||||
"""
|
||||
Decorator to mark a route as requiring authorization.
|
||||
Usage:
|
||||
@@ -92,29 +85,14 @@ class Router:
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def websocket(cls, path: str, registry: RouteRegistry = None, **kwargs):
|
||||
from cpl.api.model.websocket_route import WebSocketRoute
|
||||
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry=None, **kwargs):
|
||||
if not registry:
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
def inner(fn):
|
||||
routes.add(WebSocketRoute(path, fn, **kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path: str, method: HTTPMethods, registry: RouteRegistry = None, **kwargs):
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
if not registry:
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
else:
|
||||
routes = registry
|
||||
|
||||
def inner(fn):
|
||||
routes.add(ApiRoute(path, fn, method, **kwargs))
|
||||
@@ -159,9 +137,8 @@ class Router:
|
||||
"""
|
||||
|
||||
from cpl.api.model.api_route import ApiRoute
|
||||
|
||||
routes = get_provider().get_service(RouteRegistry)
|
||||
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
routes = ServiceProviderABC.get_global_service(RouteRegistry)
|
||||
def inner(fn):
|
||||
path = getattr(fn, "_route_path", None)
|
||||
if path is None:
|
||||
@@ -170,7 +147,7 @@ class Router:
|
||||
route = routes.get(path)
|
||||
if route is None:
|
||||
raise ValueError(f"Cannot override a route that does not exist: {path}")
|
||||
|
||||
|
||||
routes.add(ApiRoute(path, fn, route.method, **route.kwargs))
|
||||
setattr(fn, "_route_path", path)
|
||||
return fn
|
||||
|
||||
@@ -6,7 +6,7 @@ from cpl.core.configuration import ConfigurationModelABC
|
||||
class ApiSettings(ConfigurationModelABC):
|
||||
|
||||
def __init__(self, src: Optional[dict] = None):
|
||||
ConfigurationModelABC.__init__(self, src)
|
||||
super().__init__(src)
|
||||
|
||||
self.option("host", str, "0.0.0.0")
|
||||
self.option("port", int, 5000)
|
||||
|
||||
@@ -2,15 +2,13 @@ from typing import Union, Literal, Callable, Type, Awaitable
|
||||
from urllib.request import Request
|
||||
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
|
||||
from cpl.auth.schema import User
|
||||
from cpl.auth.schema import AuthUser
|
||||
|
||||
TRequest = Union[Request, WebSocket]
|
||||
TEndpoint = Callable[[TRequest, ...], Awaitable[Response]] | Callable[[TRequest, ...], Response]
|
||||
HTTPMethods = Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
|
||||
PartialMiddleware = Union[
|
||||
ASGIMiddleware,
|
||||
@@ -18,5 +16,4 @@ PartialMiddleware = Union[
|
||||
Middleware,
|
||||
Callable[[ASGIApp], ASGIApp],
|
||||
]
|
||||
PolicyResolver = Callable[[User], bool | Awaitable[bool]]
|
||||
PolicyInput = Union[dict[str, PolicyResolver], "Policy"]
|
||||
PolicyResolver = Callable[[AuthUser], bool | Awaitable[bool]]
|
||||
@@ -1,2 +1 @@
|
||||
from .application_builder import ApplicationBuilder
|
||||
from .host import Host
|
||||
|
||||
@@ -2,12 +2,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import Callable, Self
|
||||
|
||||
from cpl.application.host import Host
|
||||
from cpl.core.errors import module_dependency_error
|
||||
from cpl.core.console.console import Console
|
||||
from cpl.core.log import LogSettings
|
||||
from cpl.core.log.log_level import LogLevel
|
||||
from cpl.core.log.log_settings import LogSettings
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.typing import TModule
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
|
||||
def __not_implemented__(package: str, func: Callable):
|
||||
@@ -18,10 +17,21 @@ class ApplicationABC(ABC):
|
||||
r"""ABC for the Application class
|
||||
|
||||
Parameters:
|
||||
services: :class:`cpl.dependency.service_provider.ServiceProvider`
|
||||
services: :class:`cpl.dependency.service_provider_abc.ServiceProviderABC`
|
||||
Contains instances of prepared objects
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, services: ServiceProviderABC, required_modules: list[str | object] = None):
|
||||
self._services = services
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
)
|
||||
|
||||
@property
|
||||
def required_modules(self) -> list[str]:
|
||||
return self._required_modules
|
||||
|
||||
@classmethod
|
||||
def extend(cls, name: str | Callable, func: Callable[[Self], Self]):
|
||||
r"""Extend the Application with a custom method
|
||||
@@ -38,30 +48,6 @@ class ApplicationABC(ABC):
|
||||
setattr(cls, name, func)
|
||||
return cls
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, services: ServiceProvider, loaded_modules: set[TModule], required_modules: list[str | object] = None
|
||||
):
|
||||
self._services = services
|
||||
self._modules = loaded_modules
|
||||
self._required_modules = (
|
||||
[x.__name__ if not isinstance(x, str) else x for x in required_modules] if required_modules else []
|
||||
)
|
||||
|
||||
def validate_app_required_modules(self):
|
||||
modules_names = {x.__name__ for x in self._modules}
|
||||
for module in self._required_modules:
|
||||
if module in modules_names:
|
||||
continue
|
||||
|
||||
module_dependency_error(
|
||||
type(self).__name__,
|
||||
module.__name__ if not isinstance(module, str) else module,
|
||||
ImportError(
|
||||
f"Required module '{module}' for application '{self.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
|
||||
),
|
||||
)
|
||||
|
||||
def with_logging(self, level: LogLevel = None):
|
||||
if level is None:
|
||||
from cpl.core.configuration.configuration import Configuration
|
||||
@@ -72,21 +58,14 @@ class ApplicationABC(ABC):
|
||||
logger = self._services.get_service(LoggerABC)
|
||||
logger.set_level(level)
|
||||
|
||||
def with_permissions(self, *args):
|
||||
try:
|
||||
from cpl.auth import AuthModule
|
||||
def with_permissions(self, *args, **kwargs):
|
||||
__not_implemented__("cpl-auth", self.with_permissions)
|
||||
|
||||
AuthModule.with_permissions(*args)
|
||||
except ImportError:
|
||||
__not_implemented__("cpl-auth", self.with_permissions)
|
||||
def with_migrations(self, *args, **kwargs):
|
||||
__not_implemented__("cpl-database", self.with_migrations)
|
||||
|
||||
def with_migrations(self, *args):
|
||||
try:
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
|
||||
DatabaseModule.with_migrations(self._services, *args)
|
||||
except ImportError:
|
||||
__not_implemented__("cpl-database", self.with_migrations)
|
||||
def with_seeders(self, *args, **kwargs):
|
||||
__not_implemented__("cpl-database", self.with_seeders)
|
||||
|
||||
def with_extension(self, func: Callable[[Self, ...], None], *args, **kwargs):
|
||||
r"""Extend the Application with a custom method
|
||||
@@ -106,17 +85,9 @@ class ApplicationABC(ABC):
|
||||
Called by custom Application.main
|
||||
"""
|
||||
try:
|
||||
for module in self._modules:
|
||||
if not hasattr(module, "configure") and not callable(getattr(module, "configure")):
|
||||
continue
|
||||
module.configure(self._services)
|
||||
|
||||
Host.run_app(self.main)
|
||||
Host.run(self.main)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
logger = self._services.get_service(LoggerABC)
|
||||
logger.info("Application shutdown")
|
||||
|
||||
@abstractmethod
|
||||
def main(self): ...
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class ApplicationExtensionABC(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def run(services: ServiceProvider): ...
|
||||
def run(services: ServiceProviderABC): ...
|
||||
|
||||
@@ -6,7 +6,7 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB
|
||||
from cpl.application.abc.startup_abc import StartupABC
|
||||
from cpl.application.abc.startup_extension_abc import StartupExtensionABC
|
||||
from cpl.application.host import Host
|
||||
from cpl.dependency.context import get_provider, use_root_provider
|
||||
from cpl.core.errors import dependency_error
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
TApp = TypeVar("TApp", bound=ApplicationABC)
|
||||
@@ -21,7 +21,6 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
self._app = app if app is not None else ApplicationABC
|
||||
|
||||
self._services = ServiceCollection()
|
||||
use_root_provider(self._services.build())
|
||||
|
||||
self._startup: Optional[StartupABC] = None
|
||||
self._app_extensions: list[Type[ApplicationExtensionABC]] = []
|
||||
@@ -35,12 +34,19 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
|
||||
@property
|
||||
def service_provider(self):
|
||||
provider = get_provider()
|
||||
if provider is None:
|
||||
provider = self._services.build()
|
||||
use_root_provider(provider)
|
||||
return self._services.build()
|
||||
|
||||
return provider
|
||||
def validate_app_required_modules(self, app: ApplicationABC):
|
||||
for module in app.required_modules:
|
||||
if module in self._services.loaded_modules:
|
||||
continue
|
||||
|
||||
dependency_error(
|
||||
module,
|
||||
ImportError(
|
||||
f"Required module '{module}' for application '{app.__class__.__name__}' is not loaded. Load using 'add_module({module})' method."
|
||||
),
|
||||
)
|
||||
|
||||
def with_startup(self, startup: Type[StartupABC]) -> "ApplicationBuilder":
|
||||
self._startup = startup
|
||||
@@ -69,7 +75,6 @@ class ApplicationBuilder(Generic[TApp]):
|
||||
for extension in self._app_extensions:
|
||||
Host.run(extension.run, self.service_provider)
|
||||
|
||||
use_root_provider(self._services.build())
|
||||
app = self._app(self.service_provider, self._services.loaded_modules)
|
||||
app.validate_app_required_modules()
|
||||
app = self._app(self.service_provider)
|
||||
self.validate_app_required_modules(app)
|
||||
return app
|
||||
|
||||
@@ -1,75 +1,17 @@
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.hosted.startup_task import StartupTask
|
||||
|
||||
|
||||
class Host:
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
_tasks: dict = {}
|
||||
_loop = asyncio.get_event_loop()
|
||||
|
||||
@classmethod
|
||||
def get_loop(cls) -> asyncio.AbstractEventLoop:
|
||||
if cls._loop is None:
|
||||
cls._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(cls._loop)
|
||||
def get_loop(cls):
|
||||
return cls._loop
|
||||
|
||||
@classmethod
|
||||
def run_start_tasks(cls):
|
||||
provider = get_provider()
|
||||
tasks = provider.get_services(StartupTask)
|
||||
loop = cls.get_loop()
|
||||
|
||||
for task in tasks:
|
||||
if asyncio.iscoroutinefunction(task.run):
|
||||
loop.run_until_complete(task.run())
|
||||
else:
|
||||
task.run()
|
||||
|
||||
@classmethod
|
||||
def run_hosted_services(cls):
|
||||
provider = get_provider()
|
||||
services = provider.get_hosted_services()
|
||||
loop = cls.get_loop()
|
||||
|
||||
for service in services:
|
||||
if asyncio.iscoroutinefunction(service.start):
|
||||
cls._tasks[service] = loop.create_task(service.start())
|
||||
|
||||
@classmethod
|
||||
async def _stop_all(cls):
|
||||
for service in cls._tasks.keys():
|
||||
if asyncio.iscoroutinefunction(service.stop):
|
||||
await service.stop()
|
||||
|
||||
for task in cls._tasks.values():
|
||||
task.cancel()
|
||||
|
||||
cls._tasks.clear()
|
||||
|
||||
@classmethod
|
||||
def run_app(cls, func: Callable, *args, **kwargs):
|
||||
cls.run_start_tasks()
|
||||
cls.run_hosted_services()
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
await func(*args, **kwargs)
|
||||
else:
|
||||
func(*args, **kwargs)
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
await cls._stop_all()
|
||||
|
||||
cls.get_loop().run_until_complete(runner())
|
||||
|
||||
@classmethod
|
||||
def run(cls, func: Callable, *args, **kwargs):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return cls.get_loop().run_until_complete(func(*args, **kwargs))
|
||||
return cls._loop.run_until_complete(func(*args, **kwargs))
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@@ -1,6 +1,84 @@
|
||||
from enum import Enum
|
||||
from typing import Type
|
||||
|
||||
from cpl.application.abc import ApplicationABC as _ApplicationABC
|
||||
from cpl.auth import permission as _permission
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin as _KeycloakAdmin
|
||||
from cpl.auth.keycloak.keycloak_client import KeycloakClient as _KeycloakClient
|
||||
from .auth_module import AuthModule
|
||||
from cpl.dependency.service_collection import ServiceCollection as _ServiceCollection
|
||||
from .auth_logger import AuthLogger
|
||||
from .keycloak_settings import KeycloakSettings
|
||||
from .logger import AuthLogger
|
||||
from .permission_seeder import PermissionSeeder
|
||||
|
||||
|
||||
def _with_permissions(self: _ApplicationABC, *permissions: Type[Enum]) -> _ApplicationABC:
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
|
||||
for perm in permissions:
|
||||
PermissionsRegistry.with_enum(perm)
|
||||
return self
|
||||
|
||||
|
||||
def _add_daos(collection: _ServiceCollection):
|
||||
from .schema._administration.auth_user_dao import AuthUserDao
|
||||
from .schema._administration.api_key_dao import ApiKeyDao
|
||||
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
from .schema._permission.permission_dao import PermissionDao
|
||||
from .schema._permission.role_dao import RoleDao
|
||||
from .schema._permission.role_permission_dao import RolePermissionDao
|
||||
from .schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
collection.add_singleton(AuthUserDao)
|
||||
collection.add_singleton(ApiKeyDao)
|
||||
collection.add_singleton(ApiKeyPermissionDao)
|
||||
collection.add_singleton(PermissionDao)
|
||||
collection.add_singleton(RoleDao)
|
||||
collection.add_singleton(RolePermissionDao)
|
||||
collection.add_singleton(RoleUserDao)
|
||||
|
||||
|
||||
def add_auth(collection: _ServiceCollection):
|
||||
import os
|
||||
|
||||
try:
|
||||
from cpl.database.service.migration_service import MigrationService
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
|
||||
collection.add_singleton(_KeycloakClient)
|
||||
collection.add_singleton(_KeycloakAdmin)
|
||||
|
||||
_add_daos(collection)
|
||||
|
||||
provider = collection.build()
|
||||
migration_service: MigrationService = provider.get_service(MigrationService)
|
||||
if ServerType.server_type == ServerTypes.POSTGRES:
|
||||
migration_service.with_directory(
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/postgres")
|
||||
)
|
||||
elif ServerType.server_type == ServerTypes.MYSQL:
|
||||
migration_service.with_directory(os.path.join(os.path.dirname(os.path.realpath(__file__)), "scripts/mysql"))
|
||||
except ImportError as e:
|
||||
from cpl.core.console import Console
|
||||
|
||||
Console.error("cpl-database is not installed", str(e))
|
||||
|
||||
|
||||
def add_permission(collection: _ServiceCollection):
|
||||
from .permission_seeder import PermissionSeeder
|
||||
from .permission.permissions_registry import PermissionsRegistry
|
||||
from .permission.permissions import Permissions
|
||||
|
||||
try:
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
|
||||
collection.add_singleton(DataSeederABC, PermissionSeeder)
|
||||
PermissionsRegistry.with_enum(Permissions)
|
||||
except ImportError as e:
|
||||
from cpl.core.console import Console
|
||||
|
||||
Console.error("cpl-database is not installed", str(e))
|
||||
|
||||
|
||||
_ServiceCollection.with_module(add_auth, __name__)
|
||||
_ServiceCollection.with_module(add_permission, _permission.__name__)
|
||||
_ApplicationABC.extend(_ApplicationABC.with_permissions, _with_permissions)
|
||||
|
||||
8
src/cpl-auth/cpl/auth/auth_logger.py
Normal file
8
src/cpl-auth/cpl/auth/auth_logger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from cpl.core.log import Logger
|
||||
from cpl.core.typing import Source
|
||||
|
||||
|
||||
class AuthLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source):
|
||||
Logger.__init__(self, source, "auth")
|
||||
@@ -1,56 +0,0 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Type
|
||||
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.database.model.server_type import ServerType, ServerTypes
|
||||
from cpl.database.mysql.mysql_module import MySQLModule
|
||||
from cpl.database.postgres.postgres_module import PostgresModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from .keycloak.keycloak_admin import KeycloakAdmin
|
||||
from .keycloak.keycloak_client import KeycloakClient
|
||||
from .schema._administration.api_key_dao import ApiKeyDao
|
||||
from .schema._administration.user_dao import UserDao
|
||||
from .schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
from .schema._permission.permission_dao import PermissionDao
|
||||
from .schema._permission.role_dao import RoleDao
|
||||
from .schema._permission.role_permission_dao import RolePermissionDao
|
||||
from .schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
|
||||
class AuthModule(Module):
|
||||
dependencies = [DatabaseModule, (MySQLModule, PostgresModule)]
|
||||
config = [KeycloakSettings]
|
||||
singleton = [
|
||||
KeycloakClient,
|
||||
KeycloakAdmin,
|
||||
UserDao,
|
||||
ApiKeyDao,
|
||||
ApiKeyPermissionDao,
|
||||
PermissionDao,
|
||||
RoleDao,
|
||||
RolePermissionDao,
|
||||
RoleUserDao,
|
||||
]
|
||||
scoped = []
|
||||
transient = []
|
||||
|
||||
@staticmethod
|
||||
def configure(provider: ServiceProvider):
|
||||
paths = {
|
||||
ServerTypes.POSTGRES: "scripts/postgres",
|
||||
ServerTypes.MYSQL: "scripts/mysql",
|
||||
}
|
||||
|
||||
DatabaseModule.with_migrations(
|
||||
provider, str(os.path.join(os.path.dirname(os.path.realpath(__file__)), paths[ServerType.server_type]))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def with_permissions(*permissions: Type[Enum]):
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
|
||||
for perm in permissions:
|
||||
PermissionsRegistry.with_enum(perm)
|
||||
@@ -1,13 +1,15 @@
|
||||
from keycloak import KeycloakAdmin as _KeycloakAdmin, KeycloakOpenIDConnection
|
||||
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
from cpl.auth.logger import AuthLogger
|
||||
|
||||
_logger = AuthLogger("keycloak")
|
||||
|
||||
|
||||
class KeycloakAdmin(_KeycloakAdmin):
|
||||
|
||||
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||
# logger.info("Initializing Keycloak admin")
|
||||
def __init__(self, settings: KeycloakSettings):
|
||||
_logger.info("Initializing Keycloak admin")
|
||||
_connection = KeycloakOpenIDConnection(
|
||||
server_url=settings.url,
|
||||
client_id=settings.client_id,
|
||||
|
||||
@@ -2,13 +2,15 @@ from typing import Optional
|
||||
|
||||
from keycloak import KeycloakOpenID
|
||||
|
||||
from cpl.auth.logger import AuthLogger
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.keycloak_settings import KeycloakSettings
|
||||
|
||||
_logger = AuthLogger("keycloak")
|
||||
|
||||
|
||||
class KeycloakClient(KeycloakOpenID):
|
||||
|
||||
def __init__(self, logger: AuthLogger, settings: KeycloakSettings):
|
||||
def __init__(self, settings: KeycloakSettings):
|
||||
KeycloakOpenID.__init__(
|
||||
self,
|
||||
server_url=settings.url,
|
||||
@@ -16,7 +18,7 @@ class KeycloakClient(KeycloakOpenID):
|
||||
realm_name=settings.realm,
|
||||
client_secret_key=settings.client_secret,
|
||||
)
|
||||
logger.info("Initializing Keycloak client")
|
||||
_logger.info("Initializing Keycloak client")
|
||||
|
||||
def get_user_id(self, token: str) -> Optional[str]:
|
||||
info = self.introspect(token)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class KeycloakUser:
|
||||
@@ -32,5 +32,5 @@ class KeycloakUser:
|
||||
def id(self) -> str:
|
||||
from cpl.auth import KeycloakAdmin
|
||||
|
||||
keycloak_admin: KeycloakAdmin = get_provider().get_service(KeycloakAdmin)
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user_id(self._username)
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from cpl.core.log.wrapped_logger import WrappedLogger
|
||||
|
||||
|
||||
class AuthLogger(WrappedLogger):
|
||||
|
||||
def __init__(self):
|
||||
WrappedLogger.__init__(self, "auth")
|
||||
@@ -1,4 +0,0 @@
|
||||
from .permission_module import PermissionsModule
|
||||
from .permission_seeder import PermissionSeeder
|
||||
from .permissions import Permissions
|
||||
from .permissions_registry import PermissionsRegistry
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from cpl.auth.auth_module import AuthModule
|
||||
from cpl.auth.permission.permission_seeder import PermissionSeeder
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
from cpl.auth.permission.role_seeder import RoleSeeder
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.database_module import DatabaseModule
|
||||
from cpl.dependency.module.module import Module
|
||||
from cpl.dependency.service_collection import ServiceCollection
|
||||
|
||||
|
||||
class PermissionsModule(Module):
|
||||
dependencies = [DatabaseModule, AuthModule]
|
||||
transient = [(DataSeederABC, PermissionSeeder), (DataSeederABC, RoleSeeder)]
|
||||
|
||||
@staticmethod
|
||||
def register(collection: ServiceCollection):
|
||||
PermissionsRegistry.with_enum(Permissions)
|
||||
@@ -1,60 +0,0 @@
|
||||
from cpl.auth.schema import (
|
||||
Role,
|
||||
RolePermission,
|
||||
PermissionDao,
|
||||
RoleDao,
|
||||
RolePermissionDao,
|
||||
ApiKeyDao,
|
||||
ApiKeyPermissionDao,
|
||||
UserDao,
|
||||
RoleUserDao,
|
||||
RoleUser,
|
||||
)
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.logger import DBLogger
|
||||
|
||||
|
||||
class RoleSeeder(DataSeederABC):
|
||||
def __init__(
|
||||
self,
|
||||
logger: DBLogger,
|
||||
permission_dao: PermissionDao,
|
||||
role_dao: RoleDao,
|
||||
role_permission_dao: RolePermissionDao,
|
||||
api_key_dao: ApiKeyDao,
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
user_dao: UserDao,
|
||||
role_user_dao: RoleUserDao,
|
||||
):
|
||||
DataSeederABC.__init__(self)
|
||||
self._logger = logger
|
||||
self._permission_dao = permission_dao
|
||||
self._role_dao = role_dao
|
||||
self._role_permission_dao = role_permission_dao
|
||||
self._api_key_dao = api_key_dao
|
||||
self._api_key_permission_dao = api_key_permission_dao
|
||||
self._user_dao = user_dao
|
||||
self._role_user_dao = role_user_dao
|
||||
|
||||
async def seed(self):
|
||||
self._logger.info("Creating admin role")
|
||||
roles = await self._role_dao.get_all()
|
||||
if len(roles) == 0:
|
||||
rid = await self._role_dao.create(Role(0, "admin", "Default admin role"))
|
||||
permissions = await self._permission_dao.get_all()
|
||||
|
||||
await self._role_permission_dao.create_many(
|
||||
[RolePermission(0, rid, permission.id) for permission in permissions]
|
||||
)
|
||||
|
||||
role = await self._role_dao.get_by_name("admin")
|
||||
if len(await role.users) > 0:
|
||||
return
|
||||
|
||||
users = await self._user_dao.get_all()
|
||||
if len(users) == 0:
|
||||
return
|
||||
|
||||
user = users[0]
|
||||
self._logger.warning(f"Assigning admin role to first user {user.id}")
|
||||
await self._role_user_dao.create(RoleUser(0, role.id, user.id))
|
||||
@@ -1,3 +1,4 @@
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.permission.permissions_registry import PermissionsRegistry
|
||||
from cpl.auth.schema import (
|
||||
Permission,
|
||||
@@ -13,13 +14,14 @@ from cpl.auth.schema import (
|
||||
)
|
||||
from cpl.core.utils.get_value import get_value
|
||||
from cpl.database.abc.data_seeder_abc import DataSeederABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class PermissionSeeder(DataSeederABC):
|
||||
def __init__(
|
||||
self,
|
||||
logger: DBLogger,
|
||||
permission_dao: PermissionDao,
|
||||
role_dao: RoleDao,
|
||||
role_permission_dao: RolePermissionDao,
|
||||
@@ -27,7 +29,6 @@ class PermissionSeeder(DataSeederABC):
|
||||
api_key_permission_dao: ApiKeyPermissionDao,
|
||||
):
|
||||
DataSeederABC.__init__(self)
|
||||
self._logger = logger
|
||||
self._permission_dao = permission_dao
|
||||
self._role_dao = role_dao
|
||||
self._role_permission_dao = role_permission_dao
|
||||
@@ -39,7 +40,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
possible_permissions = [permission for permission in PermissionsRegistry.get()]
|
||||
|
||||
if len(permissions) == len(possible_permissions):
|
||||
self._logger.info("Permissions already existing")
|
||||
_logger.info("Permissions already existing")
|
||||
await self._update_missing_descriptions()
|
||||
return
|
||||
|
||||
@@ -52,7 +53,7 @@ class PermissionSeeder(DataSeederABC):
|
||||
|
||||
await self._permission_dao.delete_many(to_delete, hard_delete=True)
|
||||
|
||||
self._logger.warning("Permissions incomplete")
|
||||
_logger.warning("Permissions incomplete")
|
||||
permission_names = [permission.name for permission in permissions]
|
||||
await self._permission_dao.create_many(
|
||||
[
|
||||
@@ -1,7 +1,7 @@
|
||||
from ._administration.api_key import ApiKey
|
||||
from ._administration.api_key_dao import ApiKeyDao
|
||||
from ._administration.user import User
|
||||
from ._administration.user_dao import UserDao
|
||||
from ._administration.auth_user import AuthUser
|
||||
from ._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
from ._permission.api_key_permission import ApiKeyPermission
|
||||
from ._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, Self
|
||||
from typing import Optional, Union
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
@@ -10,13 +10,12 @@ from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Id, SerialId
|
||||
from cpl.core.utils.credential_manager import CredentialManager
|
||||
from cpl.database.abc.db_model_abc import DbModelABC
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
|
||||
_logger = Logger(__name__)
|
||||
|
||||
|
||||
class ApiKey(DbModelABC[Self]):
|
||||
class ApiKey(DbModelABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -25,8 +24,8 @@ class ApiKey(DbModelABC[Self]):
|
||||
key: Union[str, bytes],
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[Id] = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._identifier = identifier
|
||||
@@ -48,7 +47,7 @@ class ApiKey(DbModelABC[Self]):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.api_key_permission_dao import ApiKeyPermissionDao
|
||||
|
||||
apiKeyPermissionDao = get_provider().get_service(ApiKeyPermissionDao)
|
||||
apiKeyPermissionDao = ServiceProviderABC.get_global_provider().get_service(ApiKeyPermissionDao)
|
||||
|
||||
return [await x.permission for x in await apiKeyPermissionDao.find_by_api_key_id(self.id)]
|
||||
|
||||
|
||||
@@ -3,12 +3,15 @@ from typing import Optional
|
||||
from cpl.auth.schema._administration.api_key import ApiKey
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ApiKeyDao(DbModelDaoABC[ApiKey]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, ApiKey, TableManager.get("api_keys"))
|
||||
DbModelDaoABC.__init__(self, __name__, ApiKey, TableManager.get("api_keys"))
|
||||
|
||||
self.attribute(ApiKey.identifier, str)
|
||||
self.attribute(ApiKey.key, str, "keystring")
|
||||
|
||||
89
src/cpl-auth/cpl/auth/schema/_administration/auth_user.py
Normal file
89
src/cpl-auth/cpl/auth/schema/_administration/auth_user.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from async_property import async_property
|
||||
from keycloak import KeycloakGetError
|
||||
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = AuthLogger(__name__)
|
||||
|
||||
|
||||
class AuthUser(DbModelABC):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
keycloak_id: str,
|
||||
deleted: bool = False,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._keycloak_id = keycloak_id
|
||||
|
||||
@property
|
||||
def keycloak_id(self) -> str:
|
||||
return self._keycloak_id
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
if self._keycloak_id == str(uuid.UUID(int=0)):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user(self._keycloak_id).get("username")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@property
|
||||
def email(self):
|
||||
if self._keycloak_id == str(uuid.UUID(int=0)):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak_admin: KeycloakAdmin = ServiceProviderABC.get_global_service(KeycloakAdmin)
|
||||
return keycloak_admin.get_user(self._keycloak_id).get("email")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
_logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@async_property
|
||||
async def roles(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
return await auth_user_dao.get_permissions(self.id)
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
return await auth_user_dao.has_permission(self.id, permission)
|
||||
|
||||
async def anonymize(self):
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
|
||||
self._keycloak_id = str(uuid.UUID(int=0))
|
||||
await auth_user_dao.update(self)
|
||||
@@ -1,23 +1,22 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
from cpl.auth.schema._permission.permission import Permission
|
||||
from cpl.auth.schema._administration.user import User
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
from cpl.database.external_data_temp_table_builder import ExternalDataTempTableBuilder
|
||||
from cpl.dependency.context import get_provider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class UserDao(DbModelDaoABC[User]):
|
||||
class AuthUserDao(DbModelDaoABC[AuthUser]):
|
||||
|
||||
def __init__(self, permission_dao: PermissionDao):
|
||||
DbModelDaoABC.__init__(self, User, TableManager.get("users"))
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
|
||||
|
||||
self._permissions = permission_dao
|
||||
|
||||
self.attribute(User.keycloak_id, str)
|
||||
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
|
||||
|
||||
async def get_users():
|
||||
return [(x.id, x.username, x.email) for x in await self.get_all()]
|
||||
@@ -31,16 +30,16 @@ class UserDao(DbModelDaoABC[User]):
|
||||
.with_value_getter(get_users)
|
||||
)
|
||||
|
||||
async def get_by_keycloak_id(self, keycloak_id: str) -> User:
|
||||
return await self.get_single_by({User.keycloak_id: keycloak_id})
|
||||
async def get_by_keycloak_id(self, keycloak_id: str) -> AuthUser:
|
||||
return await self.get_single_by({AuthUser.keycloak_id: keycloak_id})
|
||||
|
||||
async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[User]:
|
||||
return await self.find_single_by({User.keycloak_id: keycloak_id})
|
||||
async def find_by_keycloak_id(self, keycloak_id: str) -> Optional[AuthUser]:
|
||||
return await self.find_single_by({AuthUser.keycloak_id: keycloak_id})
|
||||
|
||||
async def has_permission(self, user_id: int, permission: Union[Permissions, str]) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
p = await permission_dao.get_by_name(permission if isinstance(permission, str) else permission.value)
|
||||
result = await self._db.select_map(
|
||||
f"""
|
||||
@@ -58,7 +57,7 @@ class UserDao(DbModelDaoABC[User]):
|
||||
|
||||
return result[0]["count"] > 0
|
||||
|
||||
async def get_permissions(self, user_id: int) -> list[Permission]:
|
||||
async def get_permissions(self, user_id: int) -> list[Permissions]:
|
||||
result = await self._db.select_map(
|
||||
f"""
|
||||
SELECT p.*
|
||||
@@ -70,4 +69,4 @@ class UserDao(DbModelDaoABC[User]):
|
||||
AND ru.deleted = FALSE;
|
||||
"""
|
||||
)
|
||||
return [self._permissions.to_object(x) for x in result]
|
||||
return [Permissions(p["name"]) for p in result]
|
||||
@@ -1,89 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Self
|
||||
|
||||
from async_property import async_property
|
||||
from keycloak import KeycloakGetError
|
||||
|
||||
from cpl.auth.keycloak import KeycloakAdmin
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.database.logger import DBLogger
|
||||
from cpl.dependency import get_provider
|
||||
|
||||
|
||||
class User(DbModelABC[Self]):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
keycloak_id: str,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._keycloak_id = keycloak_id
|
||||
|
||||
@property
|
||||
def keycloak_id(self) -> str:
|
||||
return self._keycloak_id
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
if self._keycloak_id == str(uuid.UUID(int=0)):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("username")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@property
|
||||
def email(self):
|
||||
if self._keycloak_id == str(uuid.UUID(int=0)):
|
||||
return "ANONYMOUS"
|
||||
|
||||
try:
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
return keycloak.get_user(self._keycloak_id).get("email")
|
||||
except KeycloakGetError as e:
|
||||
return "UNKNOWN"
|
||||
except Exception as e:
|
||||
logger = get_provider().get_service(DBLogger)
|
||||
logger.error(f"Failed to get user {self._keycloak_id} from Keycloak", e)
|
||||
return "UNKNOWN"
|
||||
|
||||
@async_property
|
||||
async def roles(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
return [await x.role for x in await role_user_dao.get_by_user_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._administration.user_dao import UserDao
|
||||
|
||||
user_dao: UserDao = get_provider().get_service(UserDao)
|
||||
return await user_dao.get_permissions(self.id)
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._administration.user_dao import UserDao
|
||||
|
||||
user_dao: UserDao = get_provider().get_service(UserDao)
|
||||
return await user_dao.has_permission(self.id, permission)
|
||||
|
||||
async def anonymize(self):
|
||||
from cpl.auth.schema._administration.user_dao import UserDao
|
||||
|
||||
user_dao: UserDao = get_provider().get_service(UserDao)
|
||||
|
||||
self._keycloak_id = str(uuid.UUID(int=0))
|
||||
await user_dao.update(self)
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProvider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class ApiKeyPermission(DbJoinModelABC):
|
||||
@@ -15,9 +15,9 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
api_key_id: SerialId,
|
||||
permission_id: SerialId,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbJoinModelABC.__init__(self, api_key_id, permission_id, id, deleted, editor_id, created, updated)
|
||||
self._api_key_id = api_key_id
|
||||
@@ -31,7 +31,7 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def api_key(self):
|
||||
from cpl.auth.schema._administration.api_key_dao import ApiKeyDao
|
||||
|
||||
api_key_dao: ApiKeyDao = get_provider().get_service(ApiKeyDao)
|
||||
api_key_dao: ApiKeyDao = ServiceProviderABC.get_global_service(ApiKeyDao)
|
||||
return await api_key_dao.get_by_id(self._api_key_id)
|
||||
|
||||
@property
|
||||
@@ -42,5 +42,5 @@ class ApiKeyPermission(DbJoinModelABC):
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from cpl.auth.schema._permission.api_key_permission import ApiKeyPermission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class ApiKeyPermissionDao(DbModelDaoABC[ApiKeyPermission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||
DbModelDaoABC.__init__(self, __name__, ApiKeyPermission, TableManager.get("api_key_permissions"))
|
||||
|
||||
self.attribute(ApiKeyPermission.api_key_id, int)
|
||||
self.attribute(ApiKeyPermission.permission_id, int)
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Self
|
||||
from typing import Optional
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
|
||||
|
||||
class Permission(DbModelABC[Self]):
|
||||
class Permission(DbModelABC):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
name: str,
|
||||
description: str,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._name = name
|
||||
|
||||
@@ -3,12 +3,15 @@ from typing import Optional
|
||||
from cpl.auth.schema._permission.permission import Permission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class PermissionDao(DbModelDaoABC[Permission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, Permission, TableManager.get("permissions"))
|
||||
DbModelDaoABC.__init__(self, __name__, Permission, TableManager.get("permissions"))
|
||||
|
||||
self.attribute(Permission.name, str)
|
||||
self.attribute(Permission.description, Optional[str])
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Self
|
||||
from typing import Optional
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
from cpl.auth.permission.permissions import Permissions
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProvider, get_provider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class Role(DbModelABC[Self]):
|
||||
class Role(DbModelABC):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
name: str,
|
||||
description: str,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._name = name
|
||||
@@ -44,22 +44,22 @@ class Role(DbModelABC[Self]):
|
||||
async def permissions(self):
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
return [await x.permission for x in await role_permission_dao.get_by_role_id(self.id)]
|
||||
|
||||
@async_property
|
||||
async def users(self):
|
||||
from cpl.auth.schema._permission.role_user_dao import RoleUserDao
|
||||
|
||||
role_user_dao: RoleUserDao = get_provider().get_service(RoleUserDao)
|
||||
role_user_dao: RoleUserDao = ServiceProviderABC.get_global_service(RoleUserDao)
|
||||
return [await x.user for x in await role_user_dao.get_by_role_id(self.id)]
|
||||
|
||||
async def has_permission(self, permission: Permissions) -> bool:
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
from cpl.auth.schema._permission.role_permission_dao import RolePermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = get_provider().get_service(RolePermissionDao)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
role_permission_dao: RolePermissionDao = ServiceProviderABC.get_global_service(RolePermissionDao)
|
||||
|
||||
p = await permission_dao.get_by_name(permission.value)
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from cpl.auth.schema._permission.role import Role
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RoleDao(DbModelDaoABC[Role]):
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, Role, TableManager.get("roles"))
|
||||
DbModelDaoABC.__init__(self, __name__, Role, TableManager.get("roles"))
|
||||
self.attribute(Role.name, str)
|
||||
self.attribute(Role.description, str)
|
||||
|
||||
|
||||
@@ -1,44 +1,46 @@
|
||||
from datetime import datetime
|
||||
from typing import Self
|
||||
from typing import Optional
|
||||
|
||||
from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.database.abc import DbModelABC
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class RolePermission(DbJoinModelABC[Self]):
|
||||
class RolePermission(DbModelABC):
|
||||
def __init__(
|
||||
self,
|
||||
id: SerialId,
|
||||
role_id: SerialId,
|
||||
permission_id: SerialId,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbJoinModelABC.__init__(self, id, role_id, permission_id, deleted, editor_id, created, updated)
|
||||
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
|
||||
self._role_id = role_id
|
||||
self._permission_id = permission_id
|
||||
|
||||
@property
|
||||
def role_id(self) -> int:
|
||||
return self._source_id
|
||||
return self._role_id
|
||||
|
||||
@async_property
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._source_id)
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@property
|
||||
def permission_id(self) -> int:
|
||||
return self._foreign_id
|
||||
return self._permission_id
|
||||
|
||||
@async_property
|
||||
async def permission(self):
|
||||
from cpl.auth.schema._permission.permission_dao import PermissionDao
|
||||
|
||||
permission_dao: PermissionDao = get_provider().get_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._foreign_id)
|
||||
permission_dao: PermissionDao = ServiceProviderABC.get_global_service(PermissionDao)
|
||||
return await permission_dao.get_by_id(self._permission_id)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from cpl.auth.schema._permission.role_permission import RolePermission
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RolePermissionDao(DbModelDaoABC[RolePermission]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, RolePermission, TableManager.get("role_permissions"))
|
||||
DbModelDaoABC.__init__(self, __name__, RolePermission, TableManager.get("role_permissions"))
|
||||
|
||||
self.attribute(RolePermission.role_id, int)
|
||||
self.attribute(RolePermission.permission_id, int)
|
||||
|
||||
@@ -5,7 +5,7 @@ from async_property import async_property
|
||||
|
||||
from cpl.core.typing import SerialId
|
||||
from cpl.database.abc import DbJoinModelABC
|
||||
from cpl.dependency import ServiceProvider, get_provider
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
|
||||
class RoleUser(DbJoinModelABC):
|
||||
@@ -15,9 +15,9 @@ class RoleUser(DbJoinModelABC):
|
||||
user_id: SerialId,
|
||||
role_id: SerialId,
|
||||
deleted: bool = False,
|
||||
editor_id: SerialId | None = None,
|
||||
created: datetime | None = None,
|
||||
updated: datetime | None = None,
|
||||
editor_id: Optional[SerialId] = None,
|
||||
created: Optional[datetime] = None,
|
||||
updated: Optional[datetime] = None,
|
||||
):
|
||||
DbJoinModelABC.__init__(self, id, user_id, role_id, deleted, editor_id, created, updated)
|
||||
self._user_id = user_id
|
||||
@@ -29,10 +29,10 @@ class RoleUser(DbJoinModelABC):
|
||||
|
||||
@async_property
|
||||
async def user(self):
|
||||
from cpl.auth.schema._administration.user_dao import UserDao
|
||||
from cpl.auth.schema._administration.auth_user_dao import AuthUserDao
|
||||
|
||||
user_dao: UserDao = get_provider().get_service(UserDao)
|
||||
return await user_dao.get_by_id(self._user_id)
|
||||
auth_user_dao: AuthUserDao = ServiceProviderABC.get_global_service(AuthUserDao)
|
||||
return await auth_user_dao.get_by_id(self._user_id)
|
||||
|
||||
@property
|
||||
def role_id(self) -> int:
|
||||
@@ -42,5 +42,5 @@ class RoleUser(DbJoinModelABC):
|
||||
async def role(self):
|
||||
from cpl.auth.schema._permission.role_dao import RoleDao
|
||||
|
||||
role_dao: RoleDao = get_provider().get_service(RoleDao)
|
||||
role_dao: RoleDao = ServiceProviderABC.get_global_service(RoleDao)
|
||||
return await role_dao.get_by_id(self._role_id)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from cpl.auth.schema._permission.role_user import RoleUser
|
||||
from cpl.database import TableManager
|
||||
from cpl.database.abc import DbModelDaoABC
|
||||
from cpl.database.db_logger import DBLogger
|
||||
|
||||
_logger = DBLogger(__name__)
|
||||
|
||||
|
||||
class RoleUserDao(DbModelDaoABC[RoleUser]):
|
||||
|
||||
def __init__(self):
|
||||
DbModelDaoABC.__init__(self, RoleUser, TableManager.get("role_users"))
|
||||
DbModelDaoABC.__init__(self, __name__, RoleUser, TableManager.get("role_users"))
|
||||
|
||||
self.attribute(RoleUser.role_id, int)
|
||||
self.attribute(RoleUser.user_id, int)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
CREATE TABLE IF NOT EXISTS administration_users
|
||||
CREATE TABLE IF NOT EXISTS administration_auth_users
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
keycloakId CHAR(36) NOT NULL,
|
||||
@@ -9,10 +9,10 @@ CREATE TABLE IF NOT EXISTS administration_users
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT UC_KeycloakId UNIQUE (keycloakId),
|
||||
CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT FK_EditorId FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS administration_users_history
|
||||
CREATE TABLE IF NOT EXISTS administration_auth_users_history
|
||||
(
|
||||
id INT NOT NULL,
|
||||
keycloakId CHAR(36) NOT NULL,
|
||||
@@ -23,22 +23,22 @@ CREATE TABLE IF NOT EXISTS administration_users_history
|
||||
updated TIMESTAMP NOT NULL
|
||||
);
|
||||
|
||||
CREATE TRIGGER TR_administration_usersUpdate
|
||||
CREATE TRIGGER TR_administration_auth_usersUpdate
|
||||
AFTER UPDATE
|
||||
ON administration_users
|
||||
ON administration_auth_users
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
INSERT INTO administration_users_history
|
||||
INSERT INTO administration_auth_users_history
|
||||
(id, keycloakId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.keycloakId, OLD.deleted, OLD.editorId, OLD.created, NOW());
|
||||
END;
|
||||
|
||||
CREATE TRIGGER TR_administration_usersDelete
|
||||
CREATE TRIGGER TR_administration_auth_usersDelete
|
||||
AFTER DELETE
|
||||
ON administration_users
|
||||
ON administration_auth_users
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
INSERT INTO administration_users_history
|
||||
INSERT INTO administration_auth_users_history
|
||||
(id, keycloakId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.keycloakId, 1, OLD.editorId, OLD.created, NOW());
|
||||
END;
|
||||
@@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
|
||||
|
||||
CONSTRAINT UC_Identifier_Key UNIQUE (identifier, keyString),
|
||||
CONSTRAINT UC_Key UNIQUE (keyString),
|
||||
CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT FK_ApiKeys_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS administration_api_keys_history
|
||||
|
||||
@@ -8,7 +8,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
CONSTRAINT UQ_PermissionName UNIQUE (name),
|
||||
CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT FK_Permissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_permissions_history
|
||||
@@ -52,7 +52,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
CONSTRAINT UQ_RoleName UNIQUE (name),
|
||||
CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT FK_Roles_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_roles_history
|
||||
@@ -89,22 +89,22 @@ END;
|
||||
CREATE TABLE IF NOT EXISTS permission_role_permissions
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
roleId INT NOT NULL,
|
||||
RoleId INT NOT NULL,
|
||||
permissionId INT NOT NULL,
|
||||
deleted BOOL NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
CONSTRAINT UQ_RolePermission UNIQUE (roleId, permissionId),
|
||||
CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
|
||||
CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId),
|
||||
CONSTRAINT FK_RolePermissions_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_RolePermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT FK_RolePermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
|
||||
(
|
||||
id INT NOT NULL,
|
||||
roleId INT NOT NULL,
|
||||
RoleId INT NOT NULL,
|
||||
permissionId INT NOT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
editorId INT NULL,
|
||||
@@ -118,8 +118,8 @@ CREATE TRIGGER TR_RolePermissionsUpdate
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
INSERT INTO permission_role_permissions_history
|
||||
(id, roleId, permissionId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.roleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW());
|
||||
(id, RoleId, permissionId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.RoleId, OLD.permissionId, OLD.deleted, OLD.editorId, OLD.created, NOW());
|
||||
END;
|
||||
|
||||
CREATE TRIGGER TR_RolePermissionsDelete
|
||||
@@ -128,52 +128,52 @@ CREATE TRIGGER TR_RolePermissionsDelete
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
INSERT INTO permission_role_permissions_history
|
||||
(id, roleId, permissionId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.roleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW());
|
||||
(id, RoleId, permissionId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.RoleId, OLD.permissionId, 1, OLD.editorId, OLD.created, NOW());
|
||||
END;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_role_users
|
||||
CREATE TABLE IF NOT EXISTS permission_role_auth_users
|
||||
(
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
roleId INT NOT NULL,
|
||||
userId INT NOT NULL,
|
||||
RoleId INT NOT NULL,
|
||||
UserId INT NOT NULL,
|
||||
deleted BOOL NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
CONSTRAINT UQ_RoleUser UNIQUE (roleId, userId),
|
||||
CONSTRAINT FK_Roleusers_Role FOREIGN KEY (roleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_Roleusers_User FOREIGN KEY (userId) REFERENCES administration_users (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_Roleusers_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId),
|
||||
CONSTRAINT FK_Roleauth_users_Role FOREIGN KEY (RoleId) REFERENCES permission_roles (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_Roleauth_users_User FOREIGN KEY (UserId) REFERENCES administration_auth_users (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_Roleauth_users_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_role_users_history
|
||||
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
|
||||
(
|
||||
id INT NOT NULL,
|
||||
roleId INT NOT NULL,
|
||||
userId INT NOT NULL,
|
||||
RoleId INT NOT NULL,
|
||||
UserId INT NOT NULL,
|
||||
deleted BOOL NOT NULL,
|
||||
editorId INT NULL,
|
||||
created TIMESTAMP NOT NULL,
|
||||
updated TIMESTAMP NOT NULL
|
||||
);
|
||||
|
||||
CREATE TRIGGER TR_RoleusersUpdate
|
||||
CREATE TRIGGER TR_Roleauth_usersUpdate
|
||||
AFTER UPDATE
|
||||
ON permission_role_users
|
||||
ON permission_role_auth_users
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
INSERT INTO permission_role_users_history
|
||||
(id, roleId, userId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.roleId, OLD.userId, OLD.deleted, OLD.editorId, OLD.created, NOW());
|
||||
INSERT INTO permission_role_auth_users_history
|
||||
(id, RoleId, UserId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.RoleId, OLD.UserId, OLD.deleted, OLD.editorId, OLD.created, NOW());
|
||||
END;
|
||||
|
||||
CREATE TRIGGER TR_RoleusersDelete
|
||||
CREATE TRIGGER TR_Roleauth_usersDelete
|
||||
AFTER DELETE
|
||||
ON permission_role_users
|
||||
ON permission_role_auth_users
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
INSERT INTO permission_role_users_history
|
||||
(id, roleId, userId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.roleId, OLD.userId, 1, OLD.editorId, OLD.created, NOW());
|
||||
INSERT INTO permission_role_auth_users_history
|
||||
(id, RoleId, UserId, deleted, editorId, created, updated)
|
||||
VALUES (OLD.id, OLD.RoleId, OLD.UserId, 1, OLD.editorId, OLD.created, NOW());
|
||||
END;
|
||||
|
||||
@@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
|
||||
CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId),
|
||||
CONSTRAINT FK_ApiKeyPermissions_ApiKey FOREIGN KEY (apiKeyId) REFERENCES administration_api_keys (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_ApiKeyPermissions_Permission FOREIGN KEY (permissionId) REFERENCES permission_permissions (id) ON DELETE CASCADE,
|
||||
CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_users (id)
|
||||
CONSTRAINT FK_ApiKeyPermissions_Editor FOREIGN KEY (editorId) REFERENCES administration_auth_users (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
CREATE SCHEMA IF NOT EXISTS administration;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS administration.users
|
||||
CREATE TABLE IF NOT EXISTS administration.auth_users
|
||||
(
|
||||
id SERIAL PRIMARY KEY,
|
||||
keycloakId UUID NOT NULL,
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
|
||||
CONSTRAINT UC_KeycloakId UNIQUE (keycloakId)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS administration.users_history
|
||||
CREATE TABLE IF NOT EXISTS administration.auth_users_history
|
||||
(
|
||||
LIKE administration.users
|
||||
LIKE administration.auth_users
|
||||
);
|
||||
|
||||
CREATE TRIGGER users_history_trigger
|
||||
BEFORE INSERT OR UPDATE OR DELETE
|
||||
ON administration.users
|
||||
ON administration.auth_users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION public.history_trigger_function();
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS administration.api_keys
|
||||
keyString VARCHAR(255) NOT NULL,
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ CREATE TABLE permission.permissions
|
||||
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT UQ_PermissionName UNIQUE (name)
|
||||
@@ -35,7 +35,7 @@ CREATE TABLE permission.roles
|
||||
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT UQ_RoleName UNIQUE (name)
|
||||
@@ -61,7 +61,7 @@ CREATE TABLE permission.role_permissions
|
||||
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT UQ_RolePermission UNIQUE (RoleId, permissionId)
|
||||
@@ -83,11 +83,11 @@ CREATE TABLE permission.role_users
|
||||
(
|
||||
id SERIAL PRIMARY KEY,
|
||||
RoleId INT NOT NULL REFERENCES permission.roles (id) ON DELETE CASCADE,
|
||||
UserId INT NOT NULL REFERENCES administration.users (id) ON DELETE CASCADE,
|
||||
UserId INT NOT NULL REFERENCES administration.auth_users (id) ON DELETE CASCADE,
|
||||
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT UQ_RoleUser UNIQUE (RoleId, UserId)
|
||||
|
||||
@@ -6,7 +6,7 @@ CREATE TABLE permission.api_key_permissions
|
||||
|
||||
-- for history
|
||||
deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
editorId INT NULL REFERENCES administration.users (id),
|
||||
editorId INT NULL REFERENCES administration.auth_users (id),
|
||||
created timestamptz NOT NULL DEFAULT NOW(),
|
||||
updated timestamptz NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT UQ_ApiKeyPermission UNIQUE (apiKeyId, permissionId)
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from cpl.auth.schema._administration.user import User
|
||||
from cpl.dependency import get_provider
|
||||
from cpl.auth.auth_logger import AuthLogger
|
||||
from cpl.auth.schema._administration.auth_user import AuthUser
|
||||
|
||||
_user_context: ContextVar[Optional[User]] = ContextVar("user", default=None)
|
||||
_user_context: ContextVar[Optional[AuthUser]] = ContextVar("user", default=None)
|
||||
|
||||
_logger = AuthLogger(__name__)
|
||||
|
||||
|
||||
def set_user(user: Optional[User]):
|
||||
from cpl.core.log.logger_abc import LoggerABC
|
||||
|
||||
logger = get_provider().get_service(LoggerABC)
|
||||
logger.trace("Setting user context", user.id)
|
||||
_user_context.set(user)
|
||||
def set_user(user_id: Optional[AuthUser]):
|
||||
_logger.trace("Setting user context", user_id)
|
||||
_user_context.set(user_id)
|
||||
|
||||
|
||||
def get_user() -> Optional[User]:
|
||||
def get_user() -> Optional[AuthUser]:
|
||||
return _user_context.get()
|
||||
|
||||
@@ -3,25 +3,13 @@ import traceback
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
def dependency_error(src: str, package_name: str, e: ImportError = None) -> None:
|
||||
Console.error(f"'{package_name}' is required to use feature: {src}. Please install it and try again.")
|
||||
def dependency_error(package_name: str, e: ImportError) -> None:
|
||||
Console.error(f"'{package_name}' is required to use this feature. Please install it and try again.")
|
||||
tb = traceback.format_exc()
|
||||
if not tb.startswith("NoneType: None"):
|
||||
Console.error("->", tb)
|
||||
Console.write_line("->", tb)
|
||||
|
||||
elif e is not None:
|
||||
Console.error(f"-> {str(e)}")
|
||||
|
||||
exit(1)
|
||||
|
||||
|
||||
def module_dependency_error(src: str, module: str, e: ImportError = None) -> None:
|
||||
Console.error(f"'{module}' is required by '{src}'. Please initialize it with `add_module({module})`.")
|
||||
tb = traceback.format_exc()
|
||||
if not tb.startswith("NoneType: None"):
|
||||
Console.error("->", tb)
|
||||
|
||||
elif e is not None:
|
||||
Console.error(f"-> {str(e)}")
|
||||
Console.write_line("->", str(e))
|
||||
|
||||
exit(1)
|
||||
|
||||
@@ -2,4 +2,3 @@ from .logger import Logger
|
||||
from .logger_abc import LoggerABC
|
||||
from .log_level import LogLevel
|
||||
from .log_settings import LogSettings
|
||||
from .structured_logger import StructuredLogger
|
||||
|
||||
@@ -93,13 +93,14 @@ class Logger(LoggerABC):
|
||||
def _log(self, level: LogLevel, *messages: Messages):
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
formatted_message = self._format_message(level.value, timestamp, *messages)
|
||||
|
||||
self._write_log_to_file(level, self._file_format_message(level.value, timestamp, *messages))
|
||||
self._write_to_console(level, self._console_format_message(level.value, timestamp, *messages))
|
||||
self._write_log_to_file(level, formatted_message)
|
||||
self._write_to_console(level, formatted_message)
|
||||
except Exception as e:
|
||||
print(f"Error while logging: {e} -> {traceback.format_exc()}")
|
||||
|
||||
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
if isinstance(messages, tuple):
|
||||
messages = list(messages)
|
||||
|
||||
@@ -118,24 +119,6 @@ class Logger(LoggerABC):
|
||||
|
||||
return message
|
||||
|
||||
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
if isinstance(messages, tuple):
|
||||
messages = list(messages)
|
||||
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
messages = [str(message) for message in messages if message is not None]
|
||||
|
||||
message = f"[{level.upper():^3}]"
|
||||
message += f" [{self._file_prefix}]"
|
||||
if self._source is not None:
|
||||
message += f" - [{self._source}]"
|
||||
|
||||
message += f": {' '.join(messages)}"
|
||||
|
||||
return message
|
||||
|
||||
def header(self, string: str):
|
||||
self._log(LogLevel.info, string)
|
||||
|
||||
|
||||
@@ -11,10 +11,7 @@ class LoggerABC(ABC):
|
||||
def set_level(self, level: LogLevel): ...
|
||||
|
||||
@abstractmethod
|
||||
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
|
||||
def _format_message(self, level: str, timestamp, *messages: Messages) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def header(self, string: str):
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from cpl.core.log.logger import Logger
|
||||
from cpl.core.typing import Source, Messages
|
||||
from cpl.dependency.context import get_provider
|
||||
|
||||
|
||||
class StructuredLogger(Logger):
|
||||
|
||||
def __init__(self, source: Source, file_prefix: str = None):
|
||||
Logger.__init__(self, source, file_prefix)
|
||||
|
||||
@property
|
||||
def log_file(self):
|
||||
return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl"
|
||||
|
||||
def _file_format_message(self, level: str, timestamp: str, *messages: Messages) -> str:
|
||||
structured_message = {
|
||||
"timestamp": timestamp,
|
||||
"level": level.upper(),
|
||||
"source": self._source,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
self._enrich_message_with_request(structured_message)
|
||||
self._enrich_message_with_user(structured_message)
|
||||
|
||||
return json.dumps(structured_message, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _scope_to_json(request: Request, include_headers: bool = False) -> dict:
|
||||
scope = dict(request.scope)
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [convert(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): convert(v) for k, v in value.items()}
|
||||
if not isinstance(value, (str, int, float, bool, type(None))):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
serializable_scope = {str(k): convert(v) for k, v in scope.items()}
|
||||
|
||||
if not include_headers and "headers" in serializable_scope:
|
||||
serializable_scope["headers"] = "<omitted>"
|
||||
|
||||
return serializable_scope
|
||||
|
||||
def _enrich_message_with_request(self, message: dict):
|
||||
if importlib.util.find_spec("cpl.api") is None:
|
||||
return
|
||||
|
||||
from cpl.api.middleware.request import get_request
|
||||
from starlette.requests import Request
|
||||
|
||||
request = get_request()
|
||||
|
||||
if request is None:
|
||||
return
|
||||
|
||||
message["request"] = {
|
||||
"url": str(request.url),
|
||||
"method": request.method if request.scope == "http" else "websocket",
|
||||
"scope": self._scope_to_json(request),
|
||||
}
|
||||
if isinstance(request, Request) and request.scope == "http":
|
||||
request: Request = request # fix typing for IDEs
|
||||
|
||||
message["request"]["data"] = asyncio.create_task(request.body())
|
||||
|
||||
@staticmethod
|
||||
def _enrich_message_with_user(message: dict):
|
||||
if importlib.util.find_spec("cpl-auth") is None:
|
||||
return
|
||||
|
||||
from cpl.core.ctx import get_user
|
||||
|
||||
user = get_user()
|
||||
if user is None:
|
||||
return
|
||||
|
||||
from cpl.auth.keycloak.keycloak_admin import KeycloakAdmin
|
||||
|
||||
keycloak = get_provider().get_service(KeycloakAdmin)
|
||||
kc_user = keycloak.get_user(user.keycloak_id)
|
||||
message["user"] = {
|
||||
"id": str(user.id),
|
||||
"username": kc_user.get("username"),
|
||||
"email": kc_user.get("email"),
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
import inspect
|
||||
from typing import Type
|
||||
|
||||
from cpl.core.log import LoggerABC, LogLevel, StructuredLogger
|
||||
from cpl.core.typing import Messages
|
||||
from cpl.dependency.inject import inject
|
||||
from cpl.dependency.service_provider import ServiceProvider
|
||||
|
||||
|
||||
class WrappedLogger(LoggerABC):
|
||||
|
||||
def __init__(self, file_prefix: str):
|
||||
LoggerABC.__init__(self)
|
||||
assert file_prefix is not None and file_prefix != "", "file_prefix must be a non-empty string"
|
||||
|
||||
self._source = None
|
||||
self._file_prefix = file_prefix
|
||||
|
||||
self._set_logger()
|
||||
|
||||
@inject
|
||||
def _set_logger(self, services: ServiceProvider):
|
||||
from cpl.core.log import Logger
|
||||
|
||||
t_logger: Type[Logger] = services.get_service_type(LoggerABC)
|
||||
if t_logger is None:
|
||||
raise Exception("No LoggerABC service registered in ServiceProvider")
|
||||
|
||||
self._logger = t_logger(self._source, self._file_prefix)
|
||||
|
||||
def set_level(self, level: LogLevel):
|
||||
self._logger.set_level(level)
|
||||
|
||||
def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._file_format_message(level, timestamp, *messages)
|
||||
|
||||
def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str:
|
||||
return self._logger._console_format_message(level, timestamp, *messages)
|
||||
|
||||
@staticmethod
|
||||
def _get_source() -> str | None:
|
||||
stack = inspect.stack()
|
||||
if len(stack) <= 1:
|
||||
return None
|
||||
|
||||
from cpl.dependency import ServiceCollection
|
||||
|
||||
ignore_classes = [
|
||||
ServiceProvider,
|
||||
ServiceProvider.__subclasses__(),
|
||||
ServiceCollection,
|
||||
WrappedLogger,
|
||||
WrappedLogger.__subclasses__(),
|
||||
StructuredLogger,
|
||||
]
|
||||
|
||||
ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)]
|
||||
|
||||
for i, frame_info in enumerate(stack[1:]):
|
||||
module = inspect.getmodule(frame_info.frame)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
if module.__name__ in ignore_classes or module in ignore_classes:
|
||||
continue
|
||||
|
||||
if module in ignore_modules or module.__name__ in ignore_modules:
|
||||
continue
|
||||
|
||||
if module.__name__ != __name__:
|
||||
return module.__name__
|
||||
|
||||
return None
|
||||
|
||||
def _set_source(self):
|
||||
self._source = self._get_source()
|
||||
self._set_logger()
|
||||
|
||||
def header(self, string: str):
|
||||
self._set_source()
|
||||
self._logger.header(string)
|
||||
|
||||
def trace(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.trace(*messages)
|
||||
|
||||
def debug(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.debug(*messages)
|
||||
|
||||
def info(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.info(*messages)
|
||||
|
||||
def warning(self, *messages: Messages):
|
||||
self._set_source()
|
||||
self._logger.warning(*messages)
|
||||
|
||||
def error(self, messages: str, e: Exception = None):
|
||||
self._set_source()
|
||||
self._logger.error(messages, e)
|
||||
|
||||
def fatal(self, messages: str, e: Exception = None):
|
||||
self._set_source()
|
||||
self._logger.fatal(messages, e)
|
||||
@@ -1,2 +1,2 @@
|
||||
from .time_format_settings import TimeFormatSettings
|
||||
from .cron import Cron
|
||||
from .time_format_settings_names_enum import TimeFormatSettingsNamesEnum
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
import croniter
|
||||
|
||||
|
||||
class Cron:
|
||||
def __init__(self, cron_expression: str, start_time: datetime = None):
|
||||
self._cron_expression = cron_expression
|
||||
self._start_time = start_time or datetime.now()
|
||||
self._iter = croniter.croniter(cron_expression, self._start_time)
|
||||
|
||||
def next(self) -> datetime:
|
||||
return self._iter.get_next(datetime)
|
||||
@@ -13,7 +13,7 @@ class TimeFormatSettings(ConfigurationModelABC):
|
||||
date_time_format: str = None,
|
||||
date_time_log_format: str = None,
|
||||
):
|
||||
ConfigurationModelABC.__init__(self, readonly=False)
|
||||
ConfigurationModelABC.__init__(self)
|
||||
self._date_format: Optional[str] = date_format
|
||||
self._time_format: Optional[str] = time_format
|
||||
self._date_time_format: Optional[str] = date_time_format
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TimeFormatSettingsNamesEnum(Enum):
|
||||
date_format = "DateFormat"
|
||||
time_format = "TimeFormat"
|
||||
date_time_format = "DateTimeFormat"
|
||||
date_time_log_format = "DateTimeLogFormat"
|
||||
@@ -14,4 +14,3 @@ UuidId = str | UUID
|
||||
SerialId = int
|
||||
|
||||
Id = UuidId | SerialId
|
||||
TNumber = int | float | complex
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import time
|
||||
import tracemalloc
|
||||
from typing import List, Callable
|
||||
|
||||
from cpl.core.console import Console
|
||||
|
||||
|
||||
class Benchmark:
|
||||
|
||||
@staticmethod
|
||||
def all(label: str, func: Callable, iterations: int = 5):
|
||||
times: List[float] = []
|
||||
mems: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
func()
|
||||
end = time.perf_counter()
|
||||
times.append(end - start)
|
||||
|
||||
for _ in range(iterations):
|
||||
tracemalloc.start()
|
||||
func()
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
mems.append(peak)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
|
||||
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s mem {avg_mem:.8f} MB")
|
||||
|
||||
@staticmethod
|
||||
def time(label: str, func: Callable, iterations: int = 5):
|
||||
times: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
func()
|
||||
end = time.perf_counter()
|
||||
times.append(end - start)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
Console.write_line(f"{label:20s} -> min {min(times):.6f}s avg {avg_time:.6f}s")
|
||||
|
||||
@staticmethod
|
||||
def memory(label: str, func: Callable, iterations: int = 5):
|
||||
mems: List[float] = []
|
||||
|
||||
for _ in range(iterations):
|
||||
tracemalloc.start()
|
||||
func()
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
mems.append(peak)
|
||||
|
||||
avg_mem = sum(mems) / len(mems) / (1024 * 1024)
|
||||
Console.write_line(f"{label:20s} -> mem {avg_mem:.2f} MB")
|
||||
@@ -1,100 +0,0 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Generic
|
||||
|
||||
from cpl.core.typing import T
|
||||
|
||||
|
||||
class Cache(Generic[T]):
|
||||
def __init__(self, default_ttl: int = None, cleanup_interval: int = 60, t: type = None):
|
||||
self._store = {}
|
||||
self._default_ttl = default_ttl
|
||||
self._lock = threading.Lock()
|
||||
self._cleanup_interval = cleanup_interval
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
self._type = t
|
||||
|
||||
# Start background cleanup thread
|
||||
self._thread = threading.Thread(target=self._auto_cleanup, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def set(self, key: str, value: T, ttl: int = None) -> None:
|
||||
"""Store a value in the cache with optional TTL override."""
|
||||
expire_at = None
|
||||
ttl = ttl if ttl is not None else self._default_ttl
|
||||
if ttl is not None:
|
||||
expire_at = time.time() + ttl
|
||||
|
||||
with self._lock:
|
||||
self._store[key] = (value, expire_at)
|
||||
|
||||
def get(self, key: str) -> T | None:
|
||||
"""Retrieve a value from the cache if not expired."""
|
||||
with self._lock:
|
||||
item = self._store.get(key)
|
||||
if not item:
|
||||
return None
|
||||
value, expire_at = item
|
||||
if expire_at and expire_at < time.time():
|
||||
# Expired -> remove and return None
|
||||
del self._store[key]
|
||||
return None
|
||||
return value
|
||||
|
||||
def get_all(self) -> list[T]:
|
||||
"""Retrieve all non-expired values from the cache."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
valid_items = []
|
||||
expired_keys = []
|
||||
for k, (v, exp) in self._store.items():
|
||||
if exp and exp < now:
|
||||
expired_keys.append(k)
|
||||
else:
|
||||
valid_items.append(v)
|
||||
for k in expired_keys:
|
||||
del self._store[k]
|
||||
return valid_items
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
with self._lock:
|
||||
item = self._store.get(key)
|
||||
if not item:
|
||||
return False
|
||||
_, expire_at = item
|
||||
if expire_at and expire_at < time.time():
|
||||
# Expired -> remove and return False
|
||||
del self._store[key]
|
||||
return False
|
||||
return True
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Remove an item from the cache."""
|
||||
with self._lock:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the entire cache."""
|
||||
with self._lock:
|
||||
self._store.clear()
|
||||
|
||||
def _auto_cleanup(self):
|
||||
"""Background thread to clean expired items."""
|
||||
while not self._stop_event.is_set():
|
||||
self.cleanup()
|
||||
self._stop_event.wait(self._cleanup_interval)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove expired items immediately."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
expired_keys = [k for k, (_, exp) in self._store.items() if exp and exp < now]
|
||||
for k in expired_keys:
|
||||
del self._store[k]
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background cleanup thread."""
|
||||
self._stop_event.set()
|
||||
self._thread.join()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user