Redirector asgi rewrite
All checks were successful
Test before pr merge / test-translation-lint (pull_request) Successful in 38s
Test before pr merge / test-lint (pull_request) Successful in 41s
Test before pr merge / test-before-merge (pull_request) Successful in 1m40s

This commit is contained in:
Sven Heidemann 2025-03-08 10:08:03 +01:00
parent 993654dabd
commit 09196e99b1

View File

@ -1,9 +1,13 @@
import asyncio import asyncio
import sys import sys
import eventlet import uvicorn
from eventlet import wsgi from starlette.applications import Starlette
from flask import Flask, request, Response, redirect, render_template from starlette.requests import Request
from starlette.responses import RedirectResponse
from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from core.database.database import Database from core.database.database import Database
from core.environment import Environment from core.environment import Environment
@ -14,72 +18,62 @@ from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
logger = Logger(__name__) logger = Logger(__name__)
templates = Jinja2Templates(directory="templates")
class Redirector(Flask): async def index(request: Request):
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
def __init__(self, *args, **kwargs):
Flask.__init__(self, *args, **kwargs)
app = Redirector(__name__) async def handle_request(request: Request):
path = request.path_params["path"]
short_url = await shortUrlDao.find_single_by({ShortUrl.short_url: path})
@app.route("/")
def index():
return render_template("404.html"), 404
@app.route("/<path:path>")
async def _handle_request(path: str):
short_url = await _find_short_url_by_url(path)
if short_url is None: if short_url is None:
return render_template("404.html"), 404 return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
domains = Environment.get("DOMAINS", list[str], []) domains = Environment.get("DOMAINS", list[str], [])
domain = await short_url.domain domain = await short_url.domain
logger.debug( logger.debug(
f"Domain: {domain.name if domain is not None else None}, request.host: {request.host}" f"Domain: {domain.name if domain is not None else None}, request.host: {request.headers['host']}"
) )
host = request.host host = request.headers["host"]
if ":" in host: if ":" in host:
host = host.split(":")[0] host = host.split(":")[0]
domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False) domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
if domain is not None and ( if domain is not None and (
domain.name not in domains domain.name not in domains
or (domain_strict_mode and not host.endswith(domain.name)) or (domain_strict_mode and not host.endswith(domain.name))
): ):
return render_template("404.html"), 404 return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
user_agent = request.headers.get("User-Agent", "").lower() user_agent = request.headers.get("User-Agent", "").lower()
if "wheregoes" in user_agent or "someothertool" in user_agent: if "wheregoes" in user_agent or "someothertool" in user_agent:
return await _handle_short_url(path, short_url) return await _handle_short_url(request, path, short_url)
if short_url.loading_screen: if short_url.loading_screen:
await _track_visit(short_url) await _track_visit(request, short_url)
return render_template( return templates.TemplateResponse(
"redirect.html", "redirect.html",
key=short_url.short_url, {"request": request, "key": short_url.short_url, "target_url": _get_redirect_url(short_url.target_url)},
target_url=_get_redirect_url(short_url.target_url),
) )
return await _handle_short_url(path, short_url) return await _handle_short_url(request, path, short_url)
async def _handle_short_url(path: str, short_url: ShortUrl): async def _handle_short_url(request: Request, path: str, short_url: ShortUrl):
if path.startswith("api/"): if path.startswith("api/"):
path = path.replace("api/", "") path = path.replace("api/", "")
await _track_visit(short_url) await _track_visit(request, short_url)
return _do_redirect(short_url.target_url) return RedirectResponse(_get_redirect_url(short_url.target_url))
async def _track_visit(short_url: ShortUrl): async def _track_visit(request: Request, short_url: ShortUrl):
try: try:
await shortUrlVisitDao.create( await shortUrlVisitDao.create(
ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent")) ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent"))
@ -88,58 +82,43 @@ async def _track_visit(short_url: ShortUrl):
logger.error(f"Failed to update short url {short_url.short_url} with error", e) logger.error(f"Failed to update short url {short_url.short_url} with error", e)
async def _find_short_url_by_url(url: str) -> ShortUrl:
return await shortUrlDao.find_single_by({ShortUrl.short_url: url})
def _get_redirect_url(url: str) -> str: def _get_redirect_url(url: str) -> str:
# todo: multiple protocols like ts3://
if not url.startswith("http://") and not url.startswith("https://"): if not url.startswith("http://") and not url.startswith("https://"):
url = f"http://{url}" url = f"http://{url}"
return url return url
def _do_redirect(url: str) -> Response:
return redirect(_get_redirect_url(url))
async def configure(): async def configure():
Logger.set_level(Environment.get("LOG_LEVEL", str, "info")) Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production")) Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
logger.info(f"Environment: {Environment.get_environment()}") logger.info(f"Environment: {Environment.get_environment()}")
app.debug = Environment.get_environment() == "development"
await Database.startup_db() await Database.startup_db()
routes = [
Route("/", endpoint=index),
Mount('/static', StaticFiles(directory='static'), name='static'),
Route("/{path:path}", endpoint=handle_request),
]
app = Starlette(routes=routes, on_startup=[configure])
def main(): def main():
if sys.platform == "win32": if sys.platform == "win32":
from asyncio import WindowsSelectorEventLoopPolicy from asyncio import WindowsSelectorEventLoopPolicy
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
loop = asyncio.new_event_loop() uvicorn.run(
loop.run_until_complete(configure()) app,
loop.close() host="0.0.0.0",
port=Environment.get("PORT", int, 5001),
port = Environment.get("PORT", int, 5001) log_config=None,
logger.info(f"Start API on port: {port}") )
if Environment.get_environment() == "development":
logger.info(f"Playground: http://localhost:{port}/")
wsgi.server(eventlet.listen(("0.0.0.0", port)), app, log_output=False)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# ((
# ( )
# ; / ,
# / \/
# / |
# / ~/
# / ) ) ~ edraft
# ___// | /
# --' \_~-,