Redirector asgi rewrite
This commit is contained in:
parent
993654dabd
commit
09196e99b1
@ -1,9 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import eventlet
|
import uvicorn
|
||||||
from eventlet import wsgi
|
from starlette.applications import Starlette
|
||||||
from flask import Flask, request, Response, redirect, render_template
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
from starlette.routing import Route, Mount
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
from starlette.templating import Jinja2Templates
|
||||||
|
|
||||||
from core.database.database import Database
|
from core.database.database import Database
|
||||||
from core.environment import Environment
|
from core.environment import Environment
|
||||||
@ -14,72 +18,62 @@ from data.schemas.public.short_url_visit import ShortUrlVisit
|
|||||||
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
|
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
|
||||||
|
|
||||||
logger = Logger(__name__)
|
logger = Logger(__name__)
|
||||||
|
templates = Jinja2Templates(directory="templates")
|
||||||
|
|
||||||
|
|
||||||
class Redirector(Flask):
|
async def index(request: Request):
|
||||||
|
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
Flask.__init__(self, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
app = Redirector(__name__)
|
async def handle_request(request: Request):
|
||||||
|
path = request.path_params["path"]
|
||||||
|
short_url = await shortUrlDao.find_single_by({ShortUrl.short_url: path})
|
||||||
@app.route("/")
|
|
||||||
def index():
|
|
||||||
return render_template("404.html"), 404
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/<path:path>")
|
|
||||||
async def _handle_request(path: str):
|
|
||||||
short_url = await _find_short_url_by_url(path)
|
|
||||||
if short_url is None:
|
if short_url is None:
|
||||||
return render_template("404.html"), 404
|
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
|
||||||
|
|
||||||
domains = Environment.get("DOMAINS", list[str], [])
|
domains = Environment.get("DOMAINS", list[str], [])
|
||||||
domain = await short_url.domain
|
domain = await short_url.domain
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Domain: {domain.name if domain is not None else None}, request.host: {request.host}"
|
f"Domain: {domain.name if domain is not None else None}, request.host: {request.headers['host']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
host = request.host
|
host = request.headers["host"]
|
||||||
if ":" in host:
|
if ":" in host:
|
||||||
host = host.split(":")[0]
|
host = host.split(":")[0]
|
||||||
|
|
||||||
domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
|
domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
|
||||||
if domain is not None and (
|
if domain is not None and (
|
||||||
domain.name not in domains
|
domain.name not in domains
|
||||||
or (domain_strict_mode and not host.endswith(domain.name))
|
or (domain_strict_mode and not host.endswith(domain.name))
|
||||||
):
|
):
|
||||||
return render_template("404.html"), 404
|
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
|
||||||
|
|
||||||
user_agent = request.headers.get("User-Agent", "").lower()
|
user_agent = request.headers.get("User-Agent", "").lower()
|
||||||
|
|
||||||
if "wheregoes" in user_agent or "someothertool" in user_agent:
|
if "wheregoes" in user_agent or "someothertool" in user_agent:
|
||||||
return await _handle_short_url(path, short_url)
|
return await _handle_short_url(request, path, short_url)
|
||||||
|
|
||||||
if short_url.loading_screen:
|
if short_url.loading_screen:
|
||||||
await _track_visit(short_url)
|
await _track_visit(request, short_url)
|
||||||
|
|
||||||
return render_template(
|
return templates.TemplateResponse(
|
||||||
"redirect.html",
|
"redirect.html",
|
||||||
key=short_url.short_url,
|
{"request": request, "key": short_url.short_url, "target_url": _get_redirect_url(short_url.target_url)},
|
||||||
target_url=_get_redirect_url(short_url.target_url),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await _handle_short_url(path, short_url)
|
return await _handle_short_url(request, path, short_url)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_short_url(path: str, short_url: ShortUrl):
|
async def _handle_short_url(request: Request, path: str, short_url: ShortUrl):
|
||||||
if path.startswith("api/"):
|
if path.startswith("api/"):
|
||||||
path = path.replace("api/", "")
|
path = path.replace("api/", "")
|
||||||
|
|
||||||
await _track_visit(short_url)
|
await _track_visit(request, short_url)
|
||||||
|
|
||||||
return _do_redirect(short_url.target_url)
|
return RedirectResponse(_get_redirect_url(short_url.target_url))
|
||||||
|
|
||||||
|
|
||||||
async def _track_visit(short_url: ShortUrl):
|
async def _track_visit(request: Request, short_url: ShortUrl):
|
||||||
try:
|
try:
|
||||||
await shortUrlVisitDao.create(
|
await shortUrlVisitDao.create(
|
||||||
ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent"))
|
ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent"))
|
||||||
@ -88,58 +82,43 @@ async def _track_visit(short_url: ShortUrl):
|
|||||||
logger.error(f"Failed to update short url {short_url.short_url} with error", e)
|
logger.error(f"Failed to update short url {short_url.short_url} with error", e)
|
||||||
|
|
||||||
|
|
||||||
async def _find_short_url_by_url(url: str) -> ShortUrl:
|
|
||||||
return await shortUrlDao.find_single_by({ShortUrl.short_url: url})
|
|
||||||
|
|
||||||
|
|
||||||
def _get_redirect_url(url: str) -> str:
|
def _get_redirect_url(url: str) -> str:
|
||||||
# todo: multiple protocols like ts3://
|
|
||||||
if not url.startswith("http://") and not url.startswith("https://"):
|
if not url.startswith("http://") and not url.startswith("https://"):
|
||||||
url = f"http://{url}"
|
url = f"http://{url}"
|
||||||
|
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
def _do_redirect(url: str) -> Response:
|
|
||||||
return redirect(_get_redirect_url(url))
|
|
||||||
|
|
||||||
|
|
||||||
async def configure():
|
async def configure():
|
||||||
Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
|
Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
|
||||||
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
|
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
|
||||||
logger.info(f"Environment: {Environment.get_environment()}")
|
logger.info(f"Environment: {Environment.get_environment()}")
|
||||||
|
|
||||||
app.debug = Environment.get_environment() == "development"
|
|
||||||
|
|
||||||
await Database.startup_db()
|
await Database.startup_db()
|
||||||
|
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
Route("/", endpoint=index),
|
||||||
|
Mount('/static', StaticFiles(directory='static'), name='static'),
|
||||||
|
Route("/{path:path}", endpoint=handle_request),
|
||||||
|
]
|
||||||
|
|
||||||
|
app = Starlette(routes=routes, on_startup=[configure])
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
from asyncio import WindowsSelectorEventLoopPolicy
|
from asyncio import WindowsSelectorEventLoopPolicy
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
uvicorn.run(
|
||||||
loop.run_until_complete(configure())
|
app,
|
||||||
loop.close()
|
host="0.0.0.0",
|
||||||
|
port=Environment.get("PORT", int, 5001),
|
||||||
port = Environment.get("PORT", int, 5001)
|
log_config=None,
|
||||||
logger.info(f"Start API on port: {port}")
|
)
|
||||||
if Environment.get_environment() == "development":
|
|
||||||
logger.info(f"Playground: http://localhost:{port}/")
|
|
||||||
wsgi.server(eventlet.listen(("0.0.0.0", port)), app, log_output=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
# ((
|
|
||||||
# ( )
|
|
||||||
# ; / ,
|
|
||||||
# / \/
|
|
||||||
# / |
|
|
||||||
# / ~/
|
|
||||||
# / ) ) ~ edraft
|
|
||||||
# ___// | /
|
|
||||||
# --' \_~-,
|
|
Loading…
Reference in New Issue
Block a user