open-redirect/api/src/redirector.py
edraft 5c9ff4b813
Some checks failed
Test API before pr merge / test-lint (pull_request) Failing after 10s
Test before pr merge / test-translation-lint (pull_request) Successful in 39s
Test before pr merge / test-lint (pull_request) Successful in 42s
Test before pr merge / test-before-merge (pull_request) Successful in 1m56s
Added redirector caching #21
2025-05-02 13:32:15 +02:00

392 lines
12 KiB
Python

import asyncio
import sys
from datetime import datetime
from typing import Optional
import requests
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import RedirectResponse
from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from core.environment import Environment
from core.logger import Logger
# Module-wide logger, named after this module.
logger = Logger(__name__)
# Jinja2 template renderer; templates are loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")
class Cache:
    """Process-wide, time-limited cache of short-url records.

    All state lives on the class, so every caller shares the same entries.
    Each entry is stored together with the time it was written and counts as
    expired once it is older than ``CACHING_MINUTES`` minutes.
    """

    # Entry lifetime in minutes; evaluated once when the class body runs
    # (third argument is presumably the fallback when the variable is unset).
    CACHING_MINUTES = Environment.get("CACHING_MINUTES", int, 5)
    # {shortUrlKey: ShortUrl}
    _cache: dict[str, dict] = {}
    # {shortUrlKey: moment the entry was stored}
    _cache_timestamps: dict[str, datetime] = {}

    @classmethod
    def is_expired(cls, key: str) -> bool:
        """Return True if *key* has no timestamp or its entry outlived the TTL."""
        logger.trace(f"Check if cache for {key} is expired")
        timestamp = cls._cache_timestamps.get(key)
        if timestamp is None:
            return True

        age = datetime.now() - timestamp
        expired = age.total_seconds() > cls.CACHING_MINUTES * 60
        if expired:
            logger.debug(f"Cache for {key} is expired")
        return expired

    @classmethod
    def remove(cls, key: str):
        """Drop the entry and its timestamp for *key*; unknown keys are ignored."""
        logger.trace(f"Remove cache for {key}")
        cls._cache.pop(key, None)
        cls._cache_timestamps.pop(key, None)

    @classmethod
    def check_expired(cls, key: str) -> bool:
        """Evict *key* if it is expired and report whether it was expired."""
        logger.trace(f"Check expired cache for {key}")
        if cls.is_expired(key):
            cls.remove(key)
            return True
        return False

    @classmethod
    def get(cls, key: str) -> Optional[dict]:
        """Return the cached record for *key*, or None if nothing is stored.

        An expired entry is evicted but its value is still returned this one
        time, so a caller can serve the stale record while it fetches a fresh
        one.
        """
        logger.debug(f"Get cache for {key}")
        value = cls._cache.get(key, None)
        if value is not None:
            if cls.is_expired(key):
                logger.debug(f"Cache for {key} expired")
                cls.remove(key)
        return value

    @classmethod
    def set(cls, key: str, value: dict):
        """Store *value* under *key* and stamp it with the current time."""
        logger.debug(f"Set cache for {key} with value {value}")
        cls._cache[key] = value
        cls._cache_timestamps[key] = datetime.now()

    @classmethod
    def clear(cls):
        """Remove every entry together with its timestamp."""
        logger.debug("Clear cache")
        # Empty both dicts in place: values and timestamps must stay in sync,
        # and mutating (rather than rebinding) keeps the single shared dicts.
        cls._cache.clear()
        cls._cache_timestamps.clear()
async def index(request: Request):
    """Answer a request for the site root with the 404 page (status 404)."""
    context = {"request": request}
    return templates.TemplateResponse("404.html", context, status_code=404)
def _not_found_response(request: Request):
    """Render the 404 page with a 404 status code."""
    return templates.TemplateResponse(
        "404.html", {"request": request}, status_code=404
    )


def _host_matches_domain(host: str, domain_name: str) -> bool:
    """Return True if *host* is *domain_name* itself or one of its subdomains."""
    host = host.lower()
    domain_name = domain_name.lower()
    # Require a label boundary: a bare suffix test would accept
    # "evil-example.com" for the domain "example.com".
    return host == domain_name or host.endswith(f".{domain_name}")


async def handle_request(request: Request):
    """Resolve the requested path to a short url and send the visitor on.

    Responds with the 404 page when no short url matches the path, when the
    short url's domain is not listed in ``DOMAINS``, or (with
    ``DOMAIN_STRICT_MODE`` enabled) when the request host is neither that
    domain nor one of its subdomains. Otherwise the visit is tracked and the
    client is either redirected directly or shown the loading-screen page.
    """
    path = request.path_params["path"]
    short_url = await _find_short_url_by_path(path)
    if short_url is None:
        return _not_found_response(request)

    domains = Environment.get("DOMAINS", list[str], [])
    domain = short_url["domain"]
    domain_name = domain["name"] if domain is not None else None

    # A missing Host header is treated as an empty host instead of raising.
    raw_host = request.headers.get("host", "")
    logger.debug(f"Domain: {domain_name}, request.host: {raw_host}")
    # Drop an optional ":port" suffix.
    host = raw_host.partition(":")[0]

    domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
    if domain is not None and (
        domain_name not in domains
        or (domain_strict_mode and not _host_matches_domain(host, domain_name))
    ):
        return _not_found_response(request)

    # These user agents always get the plain redirect, never the loading
    # screen page.
    user_agent = request.headers.get("User-Agent", "").lower()
    skip_loading_screen = "wheregoes" in user_agent or "someothertool" in user_agent

    if short_url["loadingScreen"] and not skip_loading_screen:
        await _track_visit(request, short_url)
        return templates.TemplateResponse(
            "redirect.html",
            {
                "request": request,
                "key": short_url["shortUrl"],
                "target_url": _get_redirect_url(short_url["targetUrl"]),
            },
        )

    return await _handle_short_url(request, short_url)
# GraphQL query fetching a short url by path, both with a non-deleted group
# and without any group.
_SHORT_URL_BY_PATH_QUERY = """
query getShortUrlByPath($path: String!) {
    shortUrls(filter: [{ shortUrl: { equal: $path } }, { deleted: { equal: false } }, { group: { deleted: { equal: false } } }]) {
        nodes {
            id
            shortUrl
            targetUrl
            description
            group {
                id
                name
            }
            domain {
                id
                name
            }
            loadingScreen
            deleted
        }
    }
    shortUrlsWithoutGroup: shortUrls(filter: [{ shortUrl: { equal: $path } }, { deleted: { equal: false } }, { group: { isNull: true } }]) {
        nodes {
            id
            shortUrl
            targetUrl
            description
            group {
                id
                name
            }
            domain {
                id
                name
            }
            loadingScreen
            deleted
        }
    }
}
"""

# Strong references to in-flight background refreshes. The event loop only
# keeps weak references to tasks, so an otherwise unreferenced task could be
# garbage-collected before it completes.
_refresh_tasks: set = set()


def _on_refresh_done(task) -> None:
    """Release a finished refresh task and log its failure, if any."""
    _refresh_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        logger.warning(f"Background cache refresh failed -> {task.exception()}")


async def _find_short_url_by_path(path: str) -> Optional[dict]:
    """Look up the short url record for *path*, preferring the cache.

    A cached record is returned immediately. If it had already expired, it is
    still returned, and a background task re-fetches it from the API. On a
    cache miss the API is queried, every returned node is cached under its
    ``shortUrl`` and the first node is returned.

    Returns:
        The short url record, or None if the API returned no usable result.

    Raises:
        Exception: if ``API_URL`` or ``API_KEY`` is not configured.
    """
    from_cache = Cache.get(path)
    if from_cache is not None:
        if Cache.check_expired(path):
            # Stale hit: answer with the old record now, refresh in background.
            task = asyncio.create_task(_find_short_url_by_path(path))
            _refresh_tasks.add(task)
            task.add_done_callback(_on_refresh_done)
        return from_cache

    api_url = Environment.get("API_URL", str)
    if api_url is None:
        raise Exception("API_URL is not set")

    api_key = Environment.get("API_KEY", str)
    if api_key is None:
        raise Exception("API_KEY is not set")

    # requests is blocking; run it in a worker thread so the event loop keeps
    # serving other requests, and bound the wait with a timeout.
    response = await asyncio.to_thread(
        requests.post,
        f"{api_url}/graphql",
        json={
            "query": _SHORT_URL_BY_PATH_QUERY,
            "variables": {"path": path},
        },
        headers={"Authorization": f"API-Key {api_key}"},
        timeout=10,
    )
    data = response.json()
    if "errors" in data:
        errors = data["errors"]
        logger.warning(f"Failed to find short url by path {path} -> {errors}")

    # "data" may be absent or null when the query failed; either result set
    # may be missing as well. Treat all of those as "not found".
    payload = data.get("data") or {}
    grouped_nodes = (payload.get("shortUrls") or {}).get("nodes")
    ungrouped_nodes = (payload.get("shortUrlsWithoutGroup") or {}).get("nodes")
    if grouped_nodes is None or ungrouped_nodes is None:
        return None

    nodes = [*grouped_nodes, *ungrouped_nodes]
    if not nodes:
        return None

    for node in nodes:
        Cache.set(node["shortUrl"], node)

    return nodes[0]
async def _handle_short_url(request: Request, short_url: dict):
    """Track the visit, then redirect the client to the short url's target."""
    await _track_visit(request, short_url)
    destination = _get_redirect_url(short_url["targetUrl"])
    return RedirectResponse(destination)
# GraphQL mutation that records one visit of a short url.
_TRACK_VISIT_MUTATION = """
mutation trackShortUrlVisit($id: Int!, $agent: String) {
    shortUrl {
        trackVisit(id: $id, agent: $agent)
    }
}
"""


async def _track_visit(r: Request, short_url: dict):
    """Report a visit of *short_url* to the API (best effort).

    Sends the short url id and the visitor's User-Agent header. Any failure
    while talking to the API is logged and swallowed so that tracking problems
    never prevent the redirect.

    Raises:
        Exception: if ``API_URL`` or ``API_KEY`` is not configured.
    """
    api_url = Environment.get("API_URL", str)
    if api_url is None:
        raise Exception("API_URL is not set")

    api_key = Environment.get("API_KEY", str)
    if api_key is None:
        raise Exception("API_KEY is not set")

    try:
        # requests is blocking; run it in a worker thread so the event loop is
        # not stalled, and bound the wait with a timeout.
        response = await asyncio.to_thread(
            requests.post,
            f"{api_url}/graphql",
            json={
                "query": _TRACK_VISIT_MUTATION,
                "variables": {
                    "id": short_url["id"],
                    "agent": r.headers.get("User-Agent"),
                },
            },
            headers={"Authorization": f"API-Key {api_key}"},
            timeout=10,
        )
        key = short_url["shortUrl"]
        if response.status_code != 200:
            logger.warning(f"Failed to track visit for short url {key}")

        data = response.json()
        if "errors" in data:
            raise Exception(data["errors"])
        logger.debug(f"Tracked visit for short url {key}")
    except Exception as e:
        # Deliberately broad: tracking must never break the redirect.
        # NOTE(review): the exception is passed as a second positional
        # argument - confirm the project Logger.error accepts it.
        failed_key = short_url["shortUrl"]
        logger.error(f"Failed to update short url {failed_key} with error", e)
def _get_redirect_url(url: str) -> str:
    """Return *url* with an ``http://`` prefix added if it lacks a known scheme.

    The accepted schemes come from the ``PROTOCOLS`` setting (``http`` and
    ``https`` are passed as the fallback). Schemes are compared
    case-insensitively, so ``HTTPS://example.com`` is returned unchanged
    instead of being turned into ``http://HTTPS://example.com``. Surrounding
    whitespace is removed before the check.
    """
    protocols = Environment.get("PROTOCOLS", list[str], ["http", "https"])
    url = url.strip()
    prefixes = tuple(f"{protocol.lower()}://" for protocol in protocols)
    # str.startswith accepts a tuple; an empty tuple never matches.
    if not url.lower().startswith(prefixes):
        url = f"http://{url}"
    return url
# GraphQL query fetching every non-deleted short url, both with a
# non-deleted group and without any group.
_ALL_SHORT_URLS_QUERY = """
query getShortUrlsForCache {
    shortUrls(filter: [{ deleted: { equal: false } }, { group: { deleted: { equal: false } } }]) {
        nodes {
            id
            shortUrl
            targetUrl
            description
            group {
                id
                name
            }
            domain {
                id
                name
            }
            loadingScreen
            deleted
        }
    }
    shortUrlsWithoutGroup: shortUrls(filter: [{ deleted: { equal: false } }, { group: { isNull: true } }]) {
        nodes {
            id
            shortUrl
            targetUrl
            description
            group {
                id
                name
            }
            domain {
                id
                name
            }
            loadingScreen
            deleted
        }
    }
}
"""


def _get_all_short_urls():
    """Fetch all short urls from the API and store them in the cache.

    Every returned node is cached under its ``shortUrl``.

    Raises:
        Exception: if ``API_URL`` or ``API_KEY`` is not configured.
        ValueError: if the API response does not contain both node lists.
    """
    logger.info("Loading all short urls to cache")

    api_url = Environment.get("API_URL", str)
    if api_url is None:
        raise Exception("API_URL is not set")

    api_key = Environment.get("API_KEY", str)
    if api_key is None:
        raise Exception("API_KEY is not set")

    response = requests.post(
        f"{api_url}/graphql",
        json={
            "query": _ALL_SHORT_URLS_QUERY,
            "variables": {},
        },
        headers={"Authorization": f"API-Key {api_key}"},
        # Without a timeout an unresponsive API would block this call forever.
        timeout=30,
    )
    data = response.json()
    if "errors" in data:
        errors = data["errors"]
        logger.warning(f"Failed to get all short urls -> {errors}")

    # "data" may be absent or null when the query failed; either result set
    # may be missing as well. Report all of those with the same ValueError.
    payload = data.get("data") or {}
    grouped_nodes = (payload.get("shortUrls") or {}).get("nodes")
    ungrouped_nodes = (payload.get("shortUrlsWithoutGroup") or {}).get("nodes")
    if grouped_nodes is None or ungrouped_nodes is None:
        raise ValueError("Failed to get all short urls")

    nodes = [*grouped_nodes, *ungrouped_nodes]
    for node in nodes:
        Cache.set(node["shortUrl"], node)

    logger.info(f"Loaded {len(nodes)} short urls to cache")
async def configure():
    """Startup hook: apply log level and environment, then pre-fill the cache."""
    log_level = Environment.get("LOG_LEVEL", str, "info")
    Logger.set_level(log_level)

    environment_name = Environment.get("ENVIRONMENT", str, "production")
    Environment.set_environment(environment_name)
    logger.info(f"Environment: {Environment.get_environment()}")

    _get_all_short_urls()
# Route table. The catch-all "/{path:path}" entry is listed last, after the
# root page and the static files mount.
routes = [
    # Site root: served by `index`.
    Route("/", endpoint=index),
    # Static assets from the "static" directory.
    Mount("/static", StaticFiles(directory="static"), name="static"),
    # Any other path is treated as a short url key.
    Route("/{path:path}", endpoint=handle_request),
]
# ASGI application; `configure` is registered as a startup handler.
app = Starlette(routes=routes, on_startup=[configure])
def main():
    """Run the app with uvicorn on all interfaces.

    The port is read from the ``PORT`` setting (5001 is passed as fallback).
    On Windows the selector event loop policy is installed first.
    """
    running_on_windows = sys.platform == "win32"
    if running_on_windows:
        selector_policy = asyncio.WindowsSelectorEventLoopPolicy()
        asyncio.set_event_loop_policy(selector_policy)

    server_options = {
        "host": "0.0.0.0",
        "port": Environment.get("PORT", int, 5001),
        "log_config": None,
    }
    uvicorn.run(app, **server_options)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()