Redirector asgi rewrite
This commit is contained in:
parent
993654dabd
commit
09196e99b1
@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import eventlet
|
||||
from eventlet import wsgi
|
||||
from flask import Flask, request, Response, redirect, render_template
|
||||
import uvicorn
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.routing import Route, Mount
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
||||
from core.database.database import Database
|
||||
from core.environment import Environment
|
||||
@ -14,72 +18,62 @@ from data.schemas.public.short_url_visit import ShortUrlVisit
|
||||
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
|
||||
|
||||
logger = Logger(__name__)
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
|
||||
class Redirector(Flask):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
Flask.__init__(self, *args, **kwargs)
|
||||
async def index(request: Request):
|
||||
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
|
||||
|
||||
|
||||
app = Redirector(__name__)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return render_template("404.html"), 404
|
||||
|
||||
|
||||
@app.route("/<path:path>")
|
||||
async def _handle_request(path: str):
|
||||
short_url = await _find_short_url_by_url(path)
|
||||
async def handle_request(request: Request):
|
||||
path = request.path_params["path"]
|
||||
short_url = await shortUrlDao.find_single_by({ShortUrl.short_url: path})
|
||||
if short_url is None:
|
||||
return render_template("404.html"), 404
|
||||
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
|
||||
|
||||
domains = Environment.get("DOMAINS", list[str], [])
|
||||
domain = await short_url.domain
|
||||
logger.debug(
|
||||
f"Domain: {domain.name if domain is not None else None}, request.host: {request.host}"
|
||||
f"Domain: {domain.name if domain is not None else None}, request.host: {request.headers['host']}"
|
||||
)
|
||||
|
||||
host = request.host
|
||||
host = request.headers["host"]
|
||||
if ":" in host:
|
||||
host = host.split(":")[0]
|
||||
|
||||
domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
|
||||
if domain is not None and (
|
||||
domain.name not in domains
|
||||
or (domain_strict_mode and not host.endswith(domain.name))
|
||||
domain.name not in domains
|
||||
or (domain_strict_mode and not host.endswith(domain.name))
|
||||
):
|
||||
return render_template("404.html"), 404
|
||||
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
|
||||
|
||||
user_agent = request.headers.get("User-Agent", "").lower()
|
||||
|
||||
if "wheregoes" in user_agent or "someothertool" in user_agent:
|
||||
return await _handle_short_url(path, short_url)
|
||||
return await _handle_short_url(request, path, short_url)
|
||||
|
||||
if short_url.loading_screen:
|
||||
await _track_visit(short_url)
|
||||
await _track_visit(request, short_url)
|
||||
|
||||
return render_template(
|
||||
return templates.TemplateResponse(
|
||||
"redirect.html",
|
||||
key=short_url.short_url,
|
||||
target_url=_get_redirect_url(short_url.target_url),
|
||||
{"request": request, "key": short_url.short_url, "target_url": _get_redirect_url(short_url.target_url)},
|
||||
)
|
||||
|
||||
return await _handle_short_url(path, short_url)
|
||||
return await _handle_short_url(request, path, short_url)
|
||||
|
||||
|
||||
async def _handle_short_url(path: str, short_url: ShortUrl):
|
||||
async def _handle_short_url(request: Request, path: str, short_url: ShortUrl):
|
||||
if path.startswith("api/"):
|
||||
path = path.replace("api/", "")
|
||||
|
||||
await _track_visit(short_url)
|
||||
await _track_visit(request, short_url)
|
||||
|
||||
return _do_redirect(short_url.target_url)
|
||||
return RedirectResponse(_get_redirect_url(short_url.target_url))
|
||||
|
||||
|
||||
async def _track_visit(short_url: ShortUrl):
|
||||
async def _track_visit(request: Request, short_url: ShortUrl):
|
||||
try:
|
||||
await shortUrlVisitDao.create(
|
||||
ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent"))
|
||||
@ -88,58 +82,43 @@ async def _track_visit(short_url: ShortUrl):
|
||||
logger.error(f"Failed to update short url {short_url.short_url} with error", e)
|
||||
|
||||
|
||||
async def _find_short_url_by_url(url: str) -> ShortUrl:
|
||||
return await shortUrlDao.find_single_by({ShortUrl.short_url: url})
|
||||
|
||||
|
||||
def _get_redirect_url(url: str) -> str:
|
||||
# todo: multiple protocols like ts3://
|
||||
if not url.startswith("http://") and not url.startswith("https://"):
|
||||
url = f"http://{url}"
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def _do_redirect(url: str) -> Response:
|
||||
return redirect(_get_redirect_url(url))
|
||||
|
||||
|
||||
async def configure():
|
||||
Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
|
||||
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
|
||||
logger.info(f"Environment: {Environment.get_environment()}")
|
||||
|
||||
app.debug = Environment.get_environment() == "development"
|
||||
|
||||
await Database.startup_db()
|
||||
|
||||
|
||||
routes = [
|
||||
Route("/", endpoint=index),
|
||||
Mount('/static', StaticFiles(directory='static'), name='static'),
|
||||
Route("/{path:path}", endpoint=handle_request),
|
||||
]
|
||||
|
||||
app = Starlette(routes=routes, on_startup=[configure])
|
||||
|
||||
|
||||
def main():
|
||||
if sys.platform == "win32":
|
||||
from asyncio import WindowsSelectorEventLoopPolicy
|
||||
|
||||
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(configure())
|
||||
loop.close()
|
||||
|
||||
port = Environment.get("PORT", int, 5001)
|
||||
logger.info(f"Start API on port: {port}")
|
||||
if Environment.get_environment() == "development":
|
||||
logger.info(f"Playground: http://localhost:{port}/")
|
||||
wsgi.server(eventlet.listen(("0.0.0.0", port)), app, log_output=False)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="0.0.0.0",
|
||||
port=Environment.get("PORT", int, 5001),
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# ((
|
||||
# ( )
|
||||
# ; / ,
|
||||
# / \/
|
||||
# / |
|
||||
# / ~/
|
||||
# / ) ) ~ edraft
|
||||
# ___// | /
|
||||
# --' \_~-,
|
||||
main()
|
Loading…
Reference in New Issue
Block a user