Changed to asgi
Some checks failed
Test before pr merge / test-before-merge (pull_request) Has been cancelled

This commit is contained in:
Sven Heidemann 2025-03-08 08:18:32 +01:00
parent fafa588880
commit ee7713a25e
30 changed files with 755 additions and 132 deletions

View File

@ -1,10 +1,11 @@
ariadne==0.23.0 ariadne==0.23.0
eventlet==0.37.0 broadcaster==0.3.1
graphql-core==3.2.5 graphql-core==3.2.5
Flask[async]==3.1.0
Flask-Cors==5.0.0
async-property==0.2.2 async-property==0.2.2
python-keycloak==4.7.3
psycopg[binary]==3.2.3 psycopg[binary]==3.2.3
psycopg-pool==3.2.4 psycopg-pool==3.2.4
Werkzeug==3.1.3 uvicorn==0.34.0
starlette==0.46.0
requests==2.32.3
python-keycloak==5.3.1
python-multipart==0.0.20

View File

@ -1,78 +1,45 @@
import importlib import importlib
import os import os
import time from typing import Optional
from uuid import uuid4
from flask import Flask, request, g from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route from core.environment import Environment
from core.logger import APILogger from core.logger import APILogger
app = Flask(__name__)
logger = APILogger(__name__) logger = APILogger(__name__)
def filter_relevant_headers(headers: dict) -> dict: class API:
relevant_keys = { app: Optional[Starlette] = None
"Content-Type",
"Host",
"Connection",
"User-Agent",
"Origin",
"Referer",
"Accept",
}
return {key: value for key, value in headers.items() if key in relevant_keys}
@classmethod
def create(cls, app: Starlette):
cls.app = app
@app.before_request @staticmethod
async def log_request(): async def handle_exception(request: Request, exc: Exception):
g.request_id = uuid4() logger.error(f"Request {request.state.request_id}", exc)
g.start_time = time.time() return JSONResponse({"error": str(exc)}, status_code=500)
logger.debug(
f"Request {g.request_id}: {request.method}@{request.path} from {request.remote_addr}"
)
user = await Route.get_user()
request_info = { @staticmethod
"headers": filter_relevant_headers(dict(request.headers)), def get_allowed_origins():
"args": request.args.to_dict(), client_urls = Environment.get("CLIENT_URLS", str)
"form-data": request.form.to_dict(), if client_urls is None or client_urls == "":
"payload": request.get_json(silent=True), allowed_origins = ["*"]
"user": f"{user.id}-{user.keycloak_id}" if user else None, logger.warning("No allowed origins specified, allowing all origins")
"files": ( else:
{key: file.filename for key, file in request.files.items()} allowed_origins = client_urls.split(",")
if request.files
else None
),
}
logger.trace(f"Request {g.request_id}: {request_info}") return allowed_origins
@staticmethod
@app.after_request def import_routes():
def log_after_request(response): # used to import all routes
# calc the time it took to process the request routes_dir = os.path.join(os.path.dirname(__file__), "routes")
duration = (time.time() - g.start_time) * 1000 for filename in os.listdir(routes_dir):
logger.info( if filename.endswith(".py") and filename != "__init__.py":
f"Request finished {g.request_id}: {response.status_code}-{request.method}@{request.path} from {request.remote_addr} in {duration:.2f}ms" module_name = f"api.routes.{filename[:-3]}"
) importlib.import_module(module_name)
return response
@app.errorhandler(Exception)
def handle_exception(e):
logger.error(f"Request {g.request_id}", e)
return {"error": str(e)}, 500
# used to import all routes
routes_dir = os.path.join(os.path.dirname(__file__), "routes")
for filename in os.listdir(routes_dir):
if filename.endswith(".py") and filename != "__init__.py":
module_name = f"api.routes.{filename[:-3]}"
importlib.import_module(module_name)
# Explicitly register the routes
for route, (view_func, options) in Route.registered_routes.items():
app.add_url_rule(route, view_func=view_func, **options)

5
api/src/api/broadcast.py Normal file
View File

@ -0,0 +1,5 @@
from typing import Optional
from broadcaster import Broadcast
broadcast: Optional[Broadcast] = Broadcast("memory://")

View File

@ -1,9 +1,9 @@
from flask import jsonify from starlette.responses import JSONResponse
def unauthorized(): def unauthorized():
return jsonify({"error": "Unauthorized"}), 401 return JSONResponse({"error": "Unauthorized"}, 401)
def forbidden(): def forbidden():
return jsonify({"error": "Unauthorized"}), 401 return JSONResponse({"error": "Unauthorized"}, 401)

View File

View File

@ -0,0 +1,73 @@
import time
from uuid import uuid4
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from api.route import Route
from core.logger import APILogger
logger = APILogger("api.api")
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
await self._log_request(request)
response = await call_next(request)
await self._log_after_request(request, response)
return response
@staticmethod
def _filter_relevant_headers(headers: dict) -> dict:
relevant_keys = {
"content-type",
"host",
"connection",
"user-agent",
"origin",
"referer",
"accept",
}
return {key: value for key, value in headers.items() if key in relevant_keys}
@classmethod
async def _log_request(cls, request: Request):
request.state.request_id = uuid4()
request.state.start_time = time.time()
logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
)
user = await Route.get_user()
request_info = {
"headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params),
"form-data": (
await request.form()
if request.headers.get("content-type")
== "application/x-www-form-urlencoded"
else None
),
"payload": (
await request.json()
if request.headers.get("content-length") == "0"
else None
),
"user": f"{user.id}-{user.keycloak_id}" if user else None,
"files": (
{key: file.filename for key, file in (await request.form()).items()}
if await request.form()
else None
),
}
logger.trace(f"Request {request.state.request_id}: {request_info}")
@staticmethod
async def _log_after_request(request: Request, response: Response):
duration = (time.time() - request.state.start_time) * 1000
logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@ -0,0 +1,23 @@
from contextvars import ContextVar
from typing import Optional, Union
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
_request_context: ContextVar[Union[Request, None]] = ContextVar("request", default=None)
class RequestMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
_request_context.set(request)
from core.logger import APILogger
logger = APILogger(__name__)
logger.trace("Set new current request")
response = await call_next(request)
return response
def get_request() -> Optional[Request]:
return _request_context.get()

View File

@ -1,7 +1,8 @@
from uuid import uuid4 from uuid import uuid4
from flask import send_file from starlette.requests import Request
from werkzeug.exceptions import NotFound from starlette.responses import FileResponse
from starlette.exceptions import HTTPException
from api.route import Route from api.route import Route
from core.logger import APILogger from core.logger import APILogger
@ -9,19 +10,23 @@ from core.logger import APILogger
logger = APILogger(__name__) logger = APILogger(__name__)
@Route.get(f"/api/files/<path:file_path>") @Route.get("/api/files/{file_path:path}")
def get_file(file_path: str): async def get_file(request: Request):
file_path = request.path_params["file_path"]
name = file_path name = file_path
if "/" in file_path: if "/" in file_path:
name = file_path.split("/")[-1] name = file_path.split("/")[-1]
try: try:
return send_file( return FileResponse(
f"../files/{file_path}", download_name=name, as_attachment=True path=f"files/{file_path}",
filename=name,
media_type="application/octet-stream",
) )
except NotFound: except HTTPException as e:
return {"error": "File not found"}, 404 if e.status_code == 404:
except Exception as e: return {"error": "File not found"}, 404
error_id = uuid4() else:
logger.error(f"Error {error_id} getting file {file_path}", e) error_id = uuid4()
return {"error": f"File error. ErrorId: {error_id}"}, 500 logger.error(f"Error {error_id} getting file {file_path}", e)
return {"error": f"File error. ErrorId: {error_id}"}, 500

View File

@ -1,5 +1,6 @@
from ariadne import graphql from ariadne import graphql
from flask import request, jsonify from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route from api.route import Route
from api_graphql.service.schema import schema from api_graphql.service.schema import schema
@ -10,11 +11,11 @@ logger = Logger(__name__)
@Route.post(f"{BasePath}") @Route.post(f"{BasePath}")
async def graphql_endpoint(): async def graphql_endpoint(request: Request):
data = request.get_json() data = await request.json()
# Note: Passing the request to the context is optional. # Note: Passing the request to the context is optional.
# In Flask, the current request is always accessible as flask.request # In Starlette, the current request is accessible as request
success, result = await graphql(schema, data, context_value=request) success, result = await graphql(schema, data, context_value=request)
status_code = 200 status_code = 200
@ -24,4 +25,4 @@ async def graphql_endpoint():
] ]
status_code = max(status_codes, default=200) status_code = max(status_codes, default=200)
return jsonify(result), status_code return JSONResponse(result, status_code=status_code)

View File

@ -1,4 +1,6 @@
from ariadne.explorer import ExplorerPlayground from ariadne.explorer import ExplorerPlayground
from starlette.requests import Request
from starlette.responses import HTMLResponse
from api.route import Route from api.route import Route
from core.environment import Environment from core.environment import Environment
@ -10,7 +12,7 @@ logger = Logger(__name__)
@Route.get(f"{BasePath}/playground") @Route.get(f"{BasePath}/playground")
@Route.authorize(skip_in_dev=True) @Route.authorize(skip_in_dev=True)
async def playground(): async def playground(r: Request):
if Environment.get_environment() != "development": if Environment.get_environment() != "development":
return "", 403 return "", 403
@ -19,7 +21,6 @@ async def playground():
if dev_user: if dev_user:
request_global_headers = {f"Authorization": f"DEV-User {dev_user}"} request_global_headers = {f"Authorization": f"DEV-User {dev_user}"}
return ( return HTMLResponse(
ExplorerPlayground(request_global_headers=request_global_headers).html(None), ExplorerPlayground(request_global_headers=request_global_headers).html(None)
200,
) )

View File

@ -1,7 +1,16 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route from api.route import Route
from core.configuration.feature_flags import FeatureFlags
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from version import VERSION from version import VERSION
@Route.get(f"/api/version") @Route.get(f"/api/version")
def version(): async def version(r: Request):
return VERSION feature = await FeatureFlags.has_feature(FeatureFlagsEnum.version_endpoint)
if not feature:
return JSONResponse("DISABLED", status_code=403)
return JSONResponse(VERSION)

View File

View File

@ -0,0 +1,20 @@
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from data.schemas.system.feature_flag_dao import featureFlagDao
class FeatureFlags:
_flags = {
FeatureFlagsEnum.version_endpoint.value: True, # 15.01.2025
}
@staticmethod
def get_default(key: FeatureFlagsEnum) -> bool:
return FeatureFlags._flags[key.value]
@staticmethod
async def has_feature(key: FeatureFlagsEnum) -> bool:
value = await featureFlagDao.find_by_key(key.value)
if value is None:
return False
return value.value

View File

@ -0,0 +1,6 @@
from enum import Enum
class FeatureFlagsEnum(Enum):
# modules
version_endpoint = "VersionEndpoint"

View File

@ -1,8 +1,13 @@
import asyncio
import os import os
import traceback import traceback
from datetime import datetime from datetime import datetime
from api.middleware.request import get_request
from core.environment import Environment
class Logger: class Logger:
_level = "info" _level = "info"
_levels = ["trace", "debug", "info", "warning", "error", "fatal"] _levels = ["trace", "debug", "info", "warning", "error", "fatal"]
@ -54,6 +59,30 @@ class Logger:
else: else:
raise ValueError(f"Invalid log level: {level}") raise ValueError(f"Invalid log level: {level}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self.source,
"messages": messages,
}
request = get_request()
if request is not None:
structured_message["request"] = {
"url": str(request.url),
"method": request.method,
"data": asyncio.create_task(request.body()),
}
return str(structured_message)
def _write_log_to_file(self, content: str):
self._ensure_file_size()
with open(self.log_file, "a") as log_file:
log_file.write(content + "\n")
log_file.close()
def _log(self, level: str, *messages): def _log(self, level: str, *messages):
try: try:
if self._levels.index(level) < self._levels.index(self._level): if self._levels.index(level) < self._levels.index(self._level):
@ -63,17 +92,18 @@ class Logger:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = f"<{timestamp}> [{level.upper():^7}] [{self._file_prefix:^5}] - [{self.source}]: {' '.join(messages)}" formatted_message = f"<{timestamp}> [{level.upper():^7}] [{self._file_prefix:^5}] - [{self.source}]: {' '.join(messages)}"
self._ensure_file_size() if Environment.get("STRUCTURED_LOGGING", bool, False):
with open(self.log_file, "a") as log_file: self._write_log_to_file(
log_file.write(formatted_message + "\n") self._get_structured_message(level, timestamp, " ".join(messages))
log_file.close() )
else:
self._write_log_to_file(formatted_message)
color = self.COLORS.get(level, self.COLORS["reset"]) print(
reset_color = self.COLORS["reset"] f"{self.COLORS.get(level, self.COLORS["reset"])}{formatted_message}{self.COLORS["reset"]}"
)
print(f"{color}{formatted_message}{reset_color}")
except Exception as e: except Exception as e:
print(f"Error while logging: {e}") print(f"Error while logging: {e} -> {traceback.format_exc()}")
def trace(self, *messages): def trace(self, *messages):
self._log("trace", *messages) self._log("trace", *messages)

2
api/src/core/string.py Normal file
View File

@ -0,0 +1,2 @@
def first_to_lower(s: str) -> str:
return s[0].lower() + s[1:] if s else s

View File

@ -0,0 +1,48 @@
from datetime import datetime
from typing import Optional, Union
from async_property import async_property
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class UserSetting(DbModelABC):
def __init__(
self,
id: SerialId,
user_id: SerialId,
key: str,
value: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._user_id = user_id
self._key = key
self._value = value
@property
def user_id(self) -> SerialId:
return self._user_id
@async_property
async def user(self):
from data.schemas.administration.user_dao import userDao
return await userDao.get_by_id(self._user_id)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value
@value.setter
def value(self, value: Union[str, int, float, bool]):
self._value = str(value)

View File

@ -0,0 +1,24 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.administration.user import User
from data.schemas.public.user_setting import UserSetting
logger = DBLogger(__name__)
class UserSettingDao(DbModelDaoABC[UserSetting]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, UserSetting, "public.user_settings")
self.attribute(UserSetting.user_id, int)
self.attribute(UserSetting.key, str)
self.attribute(UserSetting.value, str)
async def find_by_key(self, user: User, key: str) -> UserSetting:
return await self.find_single_by(
[{UserSetting.user_id: user.id}, {UserSetting.key: key}]
)
userSettingsDao = UserSettingDao()

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class FeatureFlag(DbModelABC):
def __init__(
self,
id: SerialId,
key: str,
value: bool,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._key = key
self._value = value
@property
def key(self) -> str:
return self._key
@property
def value(self) -> bool:
return self._value
@value.setter
def value(self, value: bool):
self._value = value

View File

@ -0,0 +1,20 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.system.feature_flag import FeatureFlag
logger = DBLogger(__name__)
class FeatureFlagDao(DbModelDaoABC[FeatureFlag]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, FeatureFlag, "system.feature_flags")
self.attribute(FeatureFlag.key, str)
self.attribute(FeatureFlag.value, bool)
async def find_by_key(self, key: str) -> FeatureFlag:
return await self.find_single_by({FeatureFlag.key: key})
featureFlagDao = FeatureFlagDao()

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional, Union
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class Setting(DbModelABC):
def __init__(
self,
id: SerialId,
key: str,
value: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._key = key
self._value = value
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value
@value.setter
def value(self, value: Union[str, int, float, bool]):
self._value = str(value)

View File

@ -0,0 +1,20 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.system.setting import Setting
logger = DBLogger(__name__)
class SettingDao(DbModelDaoABC[Setting]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, Setting, "system.settings")
self.attribute(Setting.key, str)
self.attribute(Setting.value, str)
async def find_by_key(self, key: str) -> Setting:
return await self.find_single_by({Setting.key: key})
settingsDao = SettingDao()

View File

@ -0,0 +1,24 @@
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system.settings
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value TEXT NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE system.settings_history
(
LIKE system.settings
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON system.settings
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,24 @@
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system.feature_flags
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value BOOLEAN NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE system.feature_flags_history
(
LIKE system.feature_flags
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON system.feature_flags
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,25 @@
CREATE SCHEMA IF NOT EXISTS public;
CREATE TABLE IF NOT EXISTS public.user_settings
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value TEXT NOT NULL,
UserId INT NOT NULL REFERENCES public.user_settings (Id) ON DELETE CASCADE,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE public.user_settings_history
(
LIKE public.user_settings
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.user_settings
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,40 @@
from core.configuration.feature_flags import FeatureFlags
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from core.logger import DBLogger
from data.abc.data_seeder_abc import DataSeederABC
from data.schemas.system.feature_flag import FeatureFlag
from data.schemas.system.feature_flag_dao import featureFlagDao
logger = DBLogger(__name__)
class FeatureFlagsSeeder(DataSeederABC):
def __init__(self):
DataSeederABC.__init__(self)
async def seed(self):
logger.info("Seeding feature flags")
feature_flags = await featureFlagDao.get_all()
feature_flag_keys = [x.key for x in feature_flags]
possible_feature_flags = {
x.value: FeatureFlags.get_default(x) for x in FeatureFlagsEnum
}
to_create = [
FeatureFlag(0, x, possible_feature_flags[x])
for x in possible_feature_flags.keys()
if x not in feature_flag_keys
]
if len(to_create) > 0:
await featureFlagDao.create_many(to_create)
to_create_dicts = {x.key: x.value for x in to_create}
logger.debug(f"Created feature flags: {to_create_dicts}")
to_delete = [
x for x in feature_flags if x.key not in possible_feature_flags.keys()
]
if len(to_delete) > 0:
await featureFlagDao.delete_many(to_delete, hard_delete=True)
to_delete_dicts = {x.key: x.value for x in to_delete}
logger.debug(f"Deleted feature flags: {to_delete_dicts}")

View File

@ -0,0 +1,25 @@
from typing import Any
from core.logger import DBLogger
from data.abc.data_seeder_abc import DataSeederABC
from data.schemas.system.setting import Setting
from data.schemas.system.setting_dao import settingsDao
logger = DBLogger(__name__)
class SettingsSeeder(DataSeederABC):
def __init__(self):
DataSeederABC.__init__(self)
async def seed(self):
await self._seed_if_not_exists("default_language", "de")
await self._seed_if_not_exists("show_terms", True)
@staticmethod
async def _seed_if_not_exists(key: str, value: Any):
existing = await settingsDao.find_by_key(key)
if existing is not None:
return
await settingsDao.create(Setting(0, key, str(value)))

View File

@ -1,10 +1,9 @@
import asyncio import asyncio
import sys import sys
import eventlet import uvicorn
from eventlet import wsgi
from api.api import app from api.api import API
from core.environment import Environment from core.environment import Environment
from core.logger import Logger from core.logger import Logger
from startup import Startup from startup import Startup
@ -18,15 +17,13 @@ def main():
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
loop = asyncio.new_event_loop() Startup.configure()
loop.run_until_complete(Startup.configure()) uvicorn.run(
loop.close() API.app,
host="0.0.0.0",
port = Environment.get("PORT", int, 5000) port=Environment.get("PORT", int, 5000),
logger.info(f"Start API on port: {port}") log_config=None,
if Environment.get_environment() == "development": )
logger.info(f"Playground: http://localhost:{port}/ui/playground")
wsgi.server(eventlet.listen(("0.0.0.0", port)), app, log_output=False)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,17 +1,30 @@
from flask_cors import CORS from contextlib import asynccontextmanager
from api.api import app from ariadne.asgi import GraphQL
from ariadne.asgi.handlers import GraphQLTransportWSHandler
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import WebSocketRoute
from api.api import API
from api.auth.keycloak_client import Keycloak from api.auth.keycloak_client import Keycloak
from api.broadcast import broadcast
from api.middleware.logging import LoggingMiddleware
from api.middleware.request import RequestMiddleware
from api.route import Route
from api_graphql.service.schema import schema
from core.database.database import Database from core.database.database import Database
from core.database.database_settings import DatabaseSettings from core.database.database_settings import DatabaseSettings
from core.database.db_context import DBContext from core.database.db_context import DBContext
from core.environment import Environment from core.environment import Environment
from core.logger import Logger from core.logger import Logger
from data.seeder.api_key_seeder import ApiKeySeeder from data.seeder.api_key_seeder import ApiKeySeeder
from data.seeder.feature_flags_seeder import FeatureFlagsSeeder
from data.seeder.file_hash_seeder import FileHashSeeder from data.seeder.file_hash_seeder import FileHashSeeder
from data.seeder.permission_seeder import PermissionSeeder from data.seeder.permission_seeder import PermissionSeeder
from data.seeder.role_seeder import RoleSeeder from data.seeder.role_seeder import RoleSeeder
from data.seeder.short_url_seeder import ShortUrlSeeder from data.seeder.settings_seeder import SettingsSeeder
from data.service.migration_service import MigrationService from data.service.migration_service import MigrationService
from service.file_service import FileService from service.file_service import FileService
@ -19,15 +32,43 @@ logger = Logger(__name__)
class Startup: class Startup:
@classmethod
def _get_db_settings(cls):
host = Environment.get("DB_HOST", str)
port = Environment.get("DB_PORT", int)
user = Environment.get("DB_USER", str)
password = Environment.get("DB_PASSWORD", str)
database = Environment.get("DB_DATABASE", str)
if None in [host, port, user, password, database]:
logger.fatal(
"DB settings are not set correctly",
EnvironmentError("DB settings are not set correctly"),
)
return DatabaseSettings(
host=host, port=port, user=user, password=password, database=database
)
@classmethod
async def _startup_db(cls):
logger.info("Init DB")
db = DBContext()
await db.connect(cls._get_db_settings())
Database.init(db)
migrations = MigrationService(db)
await migrations.migrate()
@staticmethod @staticmethod
async def _seed_data(): async def _seed_data():
seeders = [ seeders = [
SettingsSeeder,
FeatureFlagsSeeder,
PermissionSeeder, PermissionSeeder,
RoleSeeder, RoleSeeder,
ApiKeySeeder, ApiKeySeeder,
FileHashSeeder, FileHashSeeder,
ShortUrlSeeder,
] ]
for seeder in [x() for x in seeders]: for seeder in [x() for x in seeders]:
await seeder.seed() await seeder.seed()
@ -38,22 +79,67 @@ class Startup:
Keycloak.init() Keycloak.init()
@classmethod @classmethod
async def configure(cls): async def _startup_broadcast(cls):
Logger.set_level(Environment.get("LOG_LEVEL", str, "info")) logger.info("Init Broadcast")
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production")) await broadcast.connect()
logger.info(f"Environment: {Environment.get_environment()}")
app.debug = Environment.get_environment() == "development" @classmethod
async def configure_api(cls):
await Database.startup_db() await cls._startup_db()
await FileService.clean_files() await FileService.clean_files()
await cls._seed_data() await cls._seed_data()
cls._startup_keycloak() cls._startup_keycloak()
await cls._startup_broadcast()
client_urls = Environment.get("CLIENT_URLS", str) @staticmethod
if client_urls is None: @asynccontextmanager
raise EnvironmentError("CLIENT_URLS not set") async def api_lifespan(app: Starlette):
await Startup.configure_api()
origins = client_urls.split(",") port = Environment.get("PORT", int, 5000)
CORS(app, support_credentials=True, resources={r"/api/*": {"origins": origins}}) logger.info(f"Start API server on port: {port}")
if Environment.get_environment() == "development":
logger.info(f"Playground: http://localhost:{port}/ui/playground")
app.debug = Environment.get_environment() == "development"
yield
logger.info("Shutdown API")
@classmethod
def init_api(cls):
logger.info("Init API")
API.import_routes()
API.create(
Starlette(
lifespan=cls.api_lifespan,
routes=[
*Route.registered_routes,
WebSocketRoute(
"/graphql",
endpoint=GraphQL(
schema, websocket_handler=GraphQLTransportWSHandler()
),
),
],
middleware=[
Middleware(RequestMiddleware),
Middleware(LoggingMiddleware),
Middleware(
CORSMiddleware,
allow_origins=API.get_allowed_origins(),
allow_methods=["*"],
allow_headers=["*"],
),
],
exception_handlers={Exception: API.handle_exception},
)
)
@classmethod
def configure(cls):
Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
logger.info(f"Environment: {Environment.get_environment()}")
cls.init_api()

79
maxlan.yaml Normal file
View File

@ -0,0 +1,79 @@
version: "3.9"
services:
open_redirect_dev_redirector:
image: git.sh-edraft.de/sh-edraft.de/open-redirect-redirector-dev:1.2.1
depends_on:
- open_redirect_dev_db
networks:
- open_redirect_dev
- traefik
environment:
- PORT=80
- ENVIRONMENT=development
- DOMAINS=maxlan.de
- DOMAIN_STRICT_MODE=false
- LOG_LEVEL=debug
- DB_HOST=open_redirect_dev_db
- DB_PORT=5432
- DB_USER=open-redirect
- DB_PASSWORD=V0R4bm9rNFlhYks2ODgyTmdDYnFXd09G
- DB_DATABASE=open-redirect
open_redirect_dev_api:
image: git.sh-edraft.de/sh-edraft.de/open-redirect-api-dev:1.2.1
depends_on:
- open_redirect_dev_db
networks:
- open_redirect_dev
- traefik
environment:
- PORT=80
- ENVIRONMENT=development
- CLIENT_URLS=https://dev.or.maxlan.de
- LOG_LEVEL=debug
- DB_HOST=open_redirect_dev_db
- DB_PORT=5432
- DB_USER=open-redirect
- DB_PASSWORD=WTNmamVXTWNNMXVFQ1NNd1RiNUZkdDJr
- DB_DATABASE=open-redirect
- KEYCLOAK_URL=https://keycloak.maxlan.de
- KEYCLOAK_CLIENT_ID=
- KEYCLOAK_REALM=
- KEYCLOAK_CLIENT_SECRET=
volumes:
- open_redirect_dev_files:/app/open_redirect/persistent
open_redirect_dev_web:
image: git.sh-edraft.de/sh-edraft.de/open-redirect-web-dev:1.2.1
depends_on:
- open_redirect_dev_api
networks:
- open_redirect_dev
- traefik
environment:
CONTAINER_NAME: "open_redirect_dev_api"
volumes:
- open_redirect_dev_web_config:/usr/share/nginx/html/assets/config
open_redirect_dev_db:
image: postgres:17
restart: always
environment:
- POSTGRES_USER=open-redirect
- POSTGRES_PASSWORD=Y3fjeWMcM1uECSMwTb5Fdt2k
- POSTGRES_DB=open-redirect
networks:
- open_redirect_dev
volumes:
- open_redirect_dev_db:/var/lib/postgresql/data
networks:
traefik:
external: true
open_redirect_dev:
volumes:
open_redirect_dev_files:
open_redirect_dev_web_config:
open_redirect_dev_db: