WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
20 changed files with 202 additions and 126 deletions
Showing only changes of commit 6de4f3c03a - Show all commits

View File

View File

@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from starlette.types import Scope, Receive, Send
class ASGIMiddleware(ABC):
@abstractmethod
def __init__(self, app):
self._app = app
def _call_next(self, scope: Scope, receive: Receive, send: Send):
return self._app(scope, receive, send)
@abstractmethod
async def __call__(self, scope: Scope, receive: Receive, send: Send): ...

View File

@@ -1,11 +1,17 @@
from http.client import HTTPException
from starlette.responses import JSONResponse
from starlette.types import Scope, Receive, Send
class APIError(HTTPException):
status_code = 500
@classmethod
async def asgi_response(cls, scope: Scope, receive: Receive, send: Send):
r = JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
return await r(scope, receive, send)
@classmethod
def response(cls):
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)

View File

@@ -1,21 +1,28 @@
from keycloak import KeycloakAuthenticationError
from starlette.types import Scope, Receive, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.auth.schema import AuthUserDao, AuthUser
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware:
class AuthenticationMiddleware(ASGIMiddleware):
def __init__(self, app):
self._app = app
@ServiceProviderABC.inject
def __init__(self, app, keycloak: KeycloakClient, user_dao: AuthUserDao):
ASGIMiddleware.__init__(self, app)
async def __call__(self, scope, receive, send):
self._keycloak = keycloak
self._user_dao = user_dao
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = get_request()
url = request.url.path
@@ -25,27 +32,41 @@ class AuthenticationMiddleware:
if not request.headers.get("Authorization"):
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return Unauthorized(f"Missing header Authorization").response()
return await Unauthorized(f"Missing header Authorization").asgi_response(scope, receive, send)
auth_header = request.headers.get("Authorization", None)
if not auth_header or not auth_header.startswith("Bearer "):
return Unauthorized("Invalid Authorization header").response()
return await Unauthorized("Invalid Authorization header").asgi_response(scope, receive, send)
token = auth_header.split("Bearer ")[1]
if not await self._verify_login(token):
_logger.debug(f"Unauthorized access to {url}, invalid token")
return Unauthorized("Invalid token").response()
return await Unauthorized("Invalid token").asgi_response(scope, receive, send)
# check user exists in db, if not create
# unauthorized if user is deleted
return await self._app(scope, receive, send)
keycloak_id = self._keycloak.get_user_id(token)
if keycloak_id is None:
return await Unauthorized("Failed to get user id from token").asgi_response(scope, receive, send)
@classmethod
async def _verify_login(cls, token: str) -> bool:
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
user = await self._get_or_crate_user(keycloak_id)
if user.deleted:
_logger.debug(f"Unauthorized access to {url}, user is deleted")
return await Unauthorized("User is deleted").asgi_response(scope, receive, send)
return await self._call_next(scope, receive, send)
async def _get_or_crate_user(self, keycloak_id: str) -> AuthUser:
existing = await self._user_dao.find_by_keycloak_id(keycloak_id)
if existing is not None:
return existing
user = AuthUser(0, keycloak_id)
uid = await self._user_dao.create(user)
return await self._user_dao.get_by_id(uid)
async def _verify_login(self, token: str) -> bool:
try:
token_info = keycloak.introspect(token)
token_info = self._keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
_logger.debug(f"Keycloak authentication error: {e}")

View File

@@ -1,20 +1,22 @@
import time
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import Receive, Scope, Send
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
class LoggingMiddleware(ASGIMiddleware):
def __init__(self, app):
ASGIMiddleware.__init__(self, app)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self.app(scope, receive, send)
await self._call_next(scope, receive, send)
return
request = get_request()
@@ -32,7 +34,7 @@ class LoggingMiddleware:
response_body += message.get("body", b"")
await send(message)
await self.app(scope, receive, send_wrapper)
await self._call_next(scope, receive, send_wrapper)
duration = (time.time() - start_time) * 1000
await self._log_after_request(request, status_code, duration)

View File

@@ -4,8 +4,10 @@ from typing import Optional, Union
from uuid import uuid4
from starlette.requests import Request
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket
from cpl.api.abc.asgi_middleware_abc import ASGIMiddleware
from cpl.api.api_logger import APILogger
from cpl.api.typing import TRequest
@@ -14,12 +16,13 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
_logger = APILogger(__name__)
class RequestMiddleware:
class RequestMiddleware(ASGIMiddleware):
def __init__(self, app):
self._app = app
ASGIMiddleware.__init__(self, app)
self._ctx_token = None
async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send):
request = Request(scope, receive, send)
await self.set_request_data(request)

View File

@@ -66,6 +66,10 @@ class WebApp(ApplicationABC):
_logger.debug(f"Allowed origins: {origins}")
return origins.split(",")
def with_database(self):
self.with_migrations()
self.with_seeders()
def with_app(self, app: Starlette):
assert app is not None, "app must not be None"
assert isinstance(app, Starlette), "app must be an instance of Starlette"
@@ -132,7 +136,7 @@ class WebApp(ApplicationABC):
def with_authorization(self):
pass
def main(self):
async def main(self):
_logger.debug(f"Preparing API")
if self._app is None:
routes = [
@@ -162,10 +166,22 @@ class WebApp(ApplicationABC):
app = self._app
_logger.info(f"Start API on {self._api_settings.host}:{self._api_settings.port}")
uvicorn.run(
# uvicorn.run(
# app,
# host=self._api_settings.host,
# port=self._api_settings.port,
# log_config=None,
# loop="asyncio"
# )
config = uvicorn.Config(
app,
host=self._api_settings.host,
port=self._api_settings.port,
log_config=None,
loop="asyncio"
)
server = uvicorn.Server(config)
await server.serve()
_logger.info("Shutdown API")

View File

@@ -88,7 +88,7 @@ class ApplicationABC(ABC):
try:
Host.run(self.main)
except KeyboardInterrupt:
Console.close()
pass
@abstractmethod
def main(self): ...

View File

@@ -1,3 +1,5 @@
from typing import Optional
from keycloak import KeycloakOpenID
from cpl.auth.auth_logger import AuthLogger
@@ -17,3 +19,7 @@ class KeycloakClient(KeycloakOpenID):
client_secret_key=settings.client_secret,
)
_logger.info("Initializing Keycloak client")
def get_user_id(self, token: str) -> Optional[str]:
info = self.introspect(token)
return info.get("sub", None)

View File

@@ -16,7 +16,7 @@ class AuthUserDao(DbModelDaoABC[AuthUser]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, AuthUser, TableManager.get("auth_users"))
self.attribute(AuthUser.keycloak_id, str, aliases=["keycloakId"])
self.attribute(AuthUser.keycloak_id, str, db_name="keycloakId")
async def get_users():
return [(x.id, x.username, x.email) for x in await self.get_all()]

View File

@@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS administration_auth_users
CREATE TABLE IF NOT EXISTS administration_auth_users_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
keycloakId CHAR(36) NOT NULL,
-- for history
deleted BOOL NOT NULL,

View File

@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS administration_api_keys
CREATE TABLE IF NOT EXISTS administration_api_keys_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
identifier VARCHAR(255) NOT NULL,
keyString VARCHAR(255) NOT NULL,
deleted BOOL NOT NULL,

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS permission_permissions
CREATE TABLE IF NOT EXISTS permission_permissions_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT NULL,
deleted BOOL NOT NULL,
@@ -57,7 +57,7 @@ CREATE TABLE IF NOT EXISTS permission_roles
CREATE TABLE IF NOT EXISTS permission_roles_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT NULL,
deleted BOOL NOT NULL,
@@ -103,7 +103,7 @@ CREATE TABLE IF NOT EXISTS permission_role_permissions
CREATE TABLE IF NOT EXISTS permission_role_permissions_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
RoleId INT NOT NULL,
permissionId INT NOT NULL,
deleted BOOL NOT NULL,
@@ -149,7 +149,7 @@ CREATE TABLE IF NOT EXISTS permission_role_auth_users
CREATE TABLE IF NOT EXISTS permission_role_auth_users_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
RoleId INT NOT NULL,
UserId INT NOT NULL,
deleted BOOL NOT NULL,

View File

@@ -15,7 +15,7 @@ CREATE TABLE IF NOT EXISTS permission_api_key_permissions
CREATE TABLE IF NOT EXISTS permission_api_key_permissions_history
(
id INT AUTO_INCREMENT PRIMARY KEY,
id INT NOT NULL,
apiKeyId INT NOT NULL,
permissionId INT NOT NULL,
deleted BOOL NOT NULL,

View File

@@ -1,3 +1,4 @@
import os
from typing import Type
from cpl.application.abc import ApplicationABC as _ApplicationABC
@@ -7,13 +8,19 @@ from . import postgres as _postgres
from .table_manager import TableManager
def _with_migrations(self: _ApplicationABC, *paths: list[str]) -> _ApplicationABC:
def _with_migrations(self: _ApplicationABC, *paths: str | list[str]) -> _ApplicationABC:
from cpl.application.host import Host
from cpl.database.service.migration_service import MigrationService
migration_service = self._services.get_service(MigrationService)
migration_service.with_directory("./scripts")
migration_service.with_directory(os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
if isinstance(paths, str):
paths = [paths]
for path in paths:
migration_service.with_directory(path)
Host.run(migration_service.migrate)
return self

View File

@@ -156,12 +156,15 @@ class DataAccessObjectABC(ABC, Generic[T_DBM]):
:param dict result: Result from the database
:return:
"""
value_map: dict[str, T] = {}
value_map: dict[str, Any] = {}
db_names = self.__db_names.items()
for db_name, value in result.items():
# Find the attribute name corresponding to the db_name
attr_name = next((k for k, v in self.__db_names.items() if v == db_name), None)
if attr_name:
attr_name = next((k for k, v in db_names if v == db_name), None)
if not attr_name:
continue
value_map[attr_name] = self._get_value_from_sql(self.__attributes[attr_name], value)
return self._model_type(**value_map)

View File

@@ -1,7 +1,7 @@
from typing import Optional, Any
import sqlparse
import aiomysql
from mysql.connector.aio import MySQLConnectionPool
from cpl.core.environment import Environment
from cpl.database.db_logger import DBLogger
@@ -9,97 +9,83 @@ from cpl.database.model import DatabaseSettings
_logger = DBLogger(__name__)
class MySQLPool:
"""
Create a pool when connecting to MySQL, which will decrease the time spent in
requesting connection, creating connection, and closing connection.
"""
def __init__(self, database_settings: DatabaseSettings):
self._db_settings = database_settings
self.pool: Optional[aiomysql.Pool] = None
self._dbconfig = {
"host": database_settings.host,
"port": database_settings.port,
"user": database_settings.user,
"password": database_settings.password,
"database": database_settings.database,
"ssl_disabled": True,
}
self._pool: Optional[MySQLConnectionPool] = None
async def _get_pool(self):
if self.pool is None or self.pool._closed:
try:
self.pool = await aiomysql.create_pool(
host=self._db_settings.host,
port=self._db_settings.port,
user=self._db_settings.user,
password=self._db_settings.password,
db=self._db_settings.database,
minsize=1,
maxsize=Environment.get("DB_POOL_SIZE", int, 1),
if self._pool is None:
self._pool = MySQLConnectionPool(
pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig
)
await self._pool.initialize_pool()
con = await self._pool.get_connection()
try:
async with await con.cursor() as cursor:
await cursor.execute("SELECT 1")
await cursor.fetchall()
except Exception as e:
_logger.fatal("Failed to connect to the database", e)
raise
return self.pool
_logger.fatal(f"Error connecting to the database: {e}")
finally:
await con.close()
return self._pool
@staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):
result = []
if multi:
queries = [str(stmt).strip() for stmt in sqlparse.parse(query) if str(stmt).strip()]
for q in queries:
if q.strip() == "":
continue
await cursor.execute(q, args)
if cursor.description is not None:
result = await cursor.fetchall()
else:
await cursor.execute(query, args)
if cursor.description is not None:
result = await cursor.fetchall()
return result
async def execute(self, query: str, args=None, multi=True) -> list[list]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
con = await pool.get_connection()
try:
async with await con.cursor() as cursor:
result = await self._exec_sql(cursor, query, args, multi)
await con.commit()
if cursor.description is not None: # Query returns rows
res = await cursor.fetchall()
if res is None:
return []
return [list(row) for row in res]
else:
return []
return result
finally:
await con.close()
async def select(self, query: str, args=None, multi=True) -> list[str]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor() as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
con = await pool.get_connection()
try:
async with await con.cursor() as cursor:
res = await self._exec_sql(cursor, query, args, multi)
return list(res)
finally:
await con.close()
async def select_map(self, query: str, args=None, multi=True) -> list[dict]:
"""
Execute a SQL statement, it could be with args and without args. The usage is
similar to the execute() function in aiomysql.
:param query: SQL clause
:param args: args needed by the SQL clause
:param multi: if the query is a multi-statement
:return: return result
"""
pool = await self._get_pool()
async with pool.acquire() as con:
async with con.cursor(aiomysql.DictCursor) as cursor:
await self._exec_sql(cursor, query, args, multi)
res = await cursor.fetchall()
con = await pool.get_connection()
try:
async with await con.cursor(dictionary=True) as cursor:
res = await self._exec_sql(cursor, query, args, multi)
return list(res)
finally:
await con.close()

View File

@@ -25,10 +25,10 @@ class PostgresPool:
f"password={database_settings.password} "
f"dbname={database_settings.database}"
)
self.pool: Optional[AsyncConnectionPool] = None
self._pool: Optional[AsyncConnectionPool] = None
async def _get_pool(self):
if self._pool is None:
pool = AsyncConnectionPool(
conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1)
)
@@ -39,7 +39,9 @@ class PostgresPool:
except PoolTimeout as e:
await pool.close()
_logger.fatal(f"Failed to connect to the database", e)
return pool
self._pool = pool
return self._pool
@staticmethod
async def _exec_sql(cursor: Any, query: str, args=None, multi=True):

View File

@@ -114,14 +114,22 @@ class ServiceProviderABC(ABC):
if f is None:
return functools.partial(cls.inject)
if iscoroutinefunction(f):
@functools.wraps(f)
async def inner(*args, **kwargs):
async def async_inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
if iscoroutinefunction(f):
return await f(*args, *injection, **kwargs)
return f(*args, *injection, **kwargs)
return async_inner
@functools.wraps(f)
def inner(*args, **kwargs):
if cls._provider is None:
raise Exception(f"{cls.__name__} not build!")
injection = [x for x in cls._provider._build_by_signature(signature(f)) if x is not None]
return f(*args, *injection, **kwargs)
return inner

View File

@@ -11,7 +11,6 @@ from service import PingService
def main():
builder = ApplicationBuilder[WebApp](WebApp)
Configuration.add_json_file(f"appsettings.json")
Configuration.add_json_file(f"appsettings.{Environment.get_environment()}.json")
Configuration.add_json_file(f"appsettings.{Environment.get_host_name()}.json", optional=True)
@@ -21,10 +20,12 @@ def main():
builder.services.add_module(api)
app = builder.build()
app.with_logging()
app.with_database()
app.with_authentication()
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
app.with_routes_directory("routes")
app.with_logging()
app.with_authentication()
app.run()