148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
from contextlib import asynccontextmanager
|
|
|
|
from ariadne.asgi import GraphQL
|
|
from ariadne.asgi.handlers import GraphQLTransportWSHandler
|
|
from starlette.applications import Starlette
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.routing import WebSocketRoute
|
|
|
|
from api.api import API
|
|
from api.auth.keycloak_client import Keycloak
|
|
from api.broadcast import broadcast
|
|
from api.middleware.logging import LoggingMiddleware
|
|
from api.middleware.request import RequestMiddleware
|
|
from api.middleware.websocket import AuthenticatedGraphQLTransportWSHandler
|
|
from api.route import Route
|
|
from api_graphql.service.schema import schema
|
|
from core.database.database import Database
|
|
from core.database.database_settings import DatabaseSettings
|
|
from core.database.db_context import DBContext
|
|
from core.environment import Environment
|
|
from core.logger import Logger
|
|
from data.seeder.api_key_seeder import ApiKeySeeder
|
|
from data.seeder.feature_flags_seeder import FeatureFlagsSeeder
|
|
from data.seeder.file_hash_seeder import FileHashSeeder
|
|
from data.seeder.permission_seeder import PermissionSeeder
|
|
from data.seeder.role_seeder import RoleSeeder
|
|
from data.seeder.settings_seeder import SettingsSeeder
|
|
from data.service.migration_service import MigrationService
|
|
from service.file_service import FileService
|
|
|
|
# Module-level logger used by the startup routines in this module.
logger = Logger(__name__)
|
|
|
|
|
|
class Startup:
    """Bootstrap the API process.

    Reads configuration from ``Environment``, prepares the backing services
    (database, seed data, Keycloak, broadcast) and builds the Starlette
    application that is handed to ``API.create``.
    """

    @classmethod
    def _get_db_settings(cls):
        """Build ``DatabaseSettings`` from the ``DB_*`` environment variables.

        Returns:
            DatabaseSettings populated from DB_HOST, DB_PORT, DB_USER,
            DB_PASSWORD and DB_DATABASE.

        Raises:
            EnvironmentError: if any of those variables resolves to ``None``.
        """
        host = Environment.get("DB_HOST", str)
        port = Environment.get("DB_PORT", int)
        user = Environment.get("DB_USER", str)
        password = Environment.get("DB_PASSWORD", str)
        database = Environment.get("DB_DATABASE", str)

        # Only variable *names* are collected here, never their values, so
        # the password cannot leak into logs or exception messages.
        values = {
            "DB_HOST": host,
            "DB_PORT": port,
            "DB_USER": user,
            "DB_PASSWORD": password,
            "DB_DATABASE": database,
        }
        missing = [name for name, value in values.items() if value is None]

        if missing:
            logger.fatal(
                "DB settings are not set correctly",
                EnvironmentError("DB settings are not set correctly"),
            )
            # Raise explicitly so incomplete settings can never reach
            # DatabaseSettings, regardless of whether logger.fatal terminates
            # the process on its own.
            raise EnvironmentError(
                f"DB settings are not set correctly; missing: {', '.join(missing)}"
            )

        return DatabaseSettings(
            host=host, port=port, user=user, password=password, database=database
        )

    @classmethod
    async def _startup_db(cls):
        """Connect a DBContext, register it with Database and run migrations."""
        logger.info("Init DB")
        db = DBContext()

        await db.connect(cls._get_db_settings())
        Database.init(db)
        migrations = MigrationService(db)
        await migrations.migrate()

    @staticmethod
    async def _seed_data():
        """Run every seeder's ``seed()`` sequentially, in list order.

        The order matters if later seeders rely on data from earlier ones
        (e.g. roles after permissions); keep it unless that is verified.
        """
        seeders = [
            SettingsSeeder,
            FeatureFlagsSeeder,
            PermissionSeeder,
            RoleSeeder,
            ApiKeySeeder,
            FileHashSeeder,
        ]
        # All seeders are instantiated before any of them runs.
        for seeder in [x() for x in seeders]:
            await seeder.seed()

    @staticmethod
    def _startup_keycloak():
        """Initialise the Keycloak client."""
        logger.info("Init Keycloak")
        Keycloak.init()

    @classmethod
    async def _startup_broadcast(cls):
        """Connect the shared broadcast backend used by GraphQL subscriptions."""
        logger.info("Init Broadcast")
        await broadcast.connect()

    @classmethod
    async def _shutdown_broadcast(cls):
        """Disconnect the broadcast backend opened by ``_startup_broadcast``."""
        # NOTE(review): assumes `broadcast` is a `broadcaster.Broadcast`
        # (connect/disconnect pair) - confirm against api/broadcast.py.
        await broadcast.disconnect()

    @classmethod
    async def configure_api(cls):
        """Run all startup steps in order.

        The steps are: DB (connect + migrate), file cleanup, seeders,
        Keycloak, then broadcast.
        """
        await cls._startup_db()
        await FileService.clean_files()

        await cls._seed_data()
        cls._startup_keycloak()
        await cls._startup_broadcast()

    @staticmethod
    @asynccontextmanager
    async def api_lifespan(app: Starlette):
        """Starlette lifespan handler for the API application.

        Before serving it calls ``configure_api``, logs the listening port
        (plus the playground URL in development) and sets ``app.debug``.
        On shutdown, including when an exception is thrown into the
        lifespan, it logs and disconnects the broadcast backend.

        Args:
            app: The Starlette application being started.
        """
        await Startup.configure_api()

        port = Environment.get("PORT", int, 5000)
        is_development = Environment.get_environment() == "development"

        logger.info(f"Start API server on port: {port}")
        if is_development:
            logger.info(f"Playground: http://localhost:{port}/ui/playground")

        app.debug = is_development
        try:
            yield
        finally:
            # finally guarantees cleanup even if the server exits with an
            # exception raised at the yield point.
            logger.info("Shutdown API")
            await Startup._shutdown_broadcast()
            # TODO(review): the DBContext opened in _startup_db is never
            # closed here; add a close call once its API is confirmed.

    @classmethod
    def init_api(cls):
        """Build the Starlette application and register it via ``API.create``.

        The app has the registered HTTP routes, a GraphQL WebSocket endpoint
        at ``/graphql``, request/logging/CORS middleware and a catch-all
        exception handler.
        """
        logger.info("Init API")
        # Presumably importing route modules populates Route.registered_routes,
        # so this must happen before that list is read below.
        API.import_routes()

        graphql_ws = WebSocketRoute(
            "/graphql",
            endpoint=GraphQL(
                schema,
                websocket_handler=AuthenticatedGraphQLTransportWSHandler(),
            ),
        )

        # Starlette wraps middleware so that the first entry is outermost.
        middleware = [
            Middleware(RequestMiddleware),
            Middleware(LoggingMiddleware),
            Middleware(
                CORSMiddleware,
                allow_origins=API.get_allowed_origins(),
                allow_methods=["*"],
                allow_headers=["*"],
            ),
        ]

        API.create(
            Starlette(
                lifespan=cls.api_lifespan,
                routes=[*Route.registered_routes, graphql_ws],
                middleware=middleware,
                exception_handlers={Exception: API.handle_exception},
            )
        )

    @classmethod
    def configure(cls):
        """Apply logging level and environment from env vars, then init the API.

        LOG_LEVEL defaults to ``"info"``; ENVIRONMENT defaults to
        ``"production"``.
        """
        Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
        Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
        logger.info(f"Environment: {Environment.get_environment()}")

        cls.init_api()
|