Some checks failed
Test API before pr merge / test-lint (pull_request) Failing after 10s
Test before pr merge / test-translation-lint (pull_request) Successful in 39s
Test before pr merge / test-lint (pull_request) Successful in 42s
Test before pr merge / test-before-merge (pull_request) Successful in 1m56s
392 lines
12 KiB
Python
392 lines
12 KiB
Python
import asyncio
|
|
import sys
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import requests
|
|
import uvicorn
|
|
from starlette.applications import Starlette
|
|
from starlette.requests import Request
|
|
from starlette.responses import RedirectResponse
|
|
from starlette.routing import Route, Mount
|
|
from starlette.staticfiles import StaticFiles
|
|
from starlette.templating import Jinja2Templates
|
|
|
|
from core.environment import Environment
|
|
from core.logger import Logger
|
|
|
|
# Module logger (project-specific core.logger.Logger, keyed by this module's name).
logger = Logger(__name__)
# Jinja2 templates (404.html, redirect.html) are loaded from the relative
# "templates" directory, i.e. resolved against the process working directory.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
|
class Cache:
    """Process-wide in-memory cache of short-URL records with time-based expiry.

    All state lives on the class itself (it is used via classmethods, never
    instantiated). Entries are keyed by the short URL path and stamped with
    the time they were stored; an entry older than ``CACHING_MINUTES`` is
    considered expired.
    """

    CACHING_MINUTES = Environment.get("CACHING_MINUTES", int, 5)

    # {shortUrlKey: ShortUrl}
    _cache: dict[str, dict] = {}
    # {shortUrlKey: time the entry was stored, as an aware UTC datetime}
    _cache_timestamps: dict[str, datetime] = {}

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Naive local time can jump backwards at a DST transition, which would
        make freshly stored entries look valid for up to an extra hour; UTC
        keeps the elapsed-time arithmetic in ``is_expired`` monotonic in practice.
        """
        from datetime import timezone

        return datetime.now(timezone.utc)

    @classmethod
    def is_expired(cls, key: str) -> bool:
        """Return True if *key* has no timestamp or is older than CACHING_MINUTES."""
        logger.trace(f"Check if cache for {key} is expired")
        timestamp = cls._cache_timestamps.get(key)
        if timestamp is None:
            return True

        diff = cls._now() - timestamp
        res = diff.total_seconds() > cls.CACHING_MINUTES * 60
        if res:
            logger.debug(f"Cache for {key} is expired")
        return res

    @classmethod
    def remove(cls, key: str):
        """Drop *key* and its timestamp; missing keys are ignored."""
        logger.trace(f"Remove cache for {key}")
        cls._cache.pop(key, None)
        cls._cache_timestamps.pop(key, None)

    @classmethod
    def check_expired(cls, key: str):
        """Evict *key* if it is expired (or unknown); return True when it was."""
        logger.trace(f"Check expired cache for {key}")
        if cls.is_expired(key):
            cls.remove(key)
            return True
        return False

    @classmethod
    def get(cls, key: str) -> Optional[dict]:
        """Return the cached value for *key*, or None if it is not cached.

        If the entry has expired it is evicted, but its (stale) value is still
        returned for this one call. Callers can serve it while refreshing in
        the background.
        """
        logger.debug(f"Get cache for {key}")
        value = cls._cache.get(key, None)
        if value is not None:
            if cls.is_expired(key):
                logger.debug(f"Cache for {key} expired")
                cls.remove(key)

        return value

    @classmethod
    def set(cls, key: str, value: dict):
        """Store *value* under *key* and stamp it with the current time."""
        logger.debug(f"Set cache for {key} with value {value}")
        cls._cache[key] = value
        cls._cache_timestamps[key] = cls._now()

    @classmethod
    def clear(cls):
        """Remove every entry, including its timestamp.

        Timestamps are cleared too so that no orphaned stamp can make a later
        lookup of the same key look fresh.
        """
        logger.debug("Clear cache")
        cls._cache = {}
        cls._cache_timestamps = {}
|
|
|
|
|
|
async def index(request: Request):
    """Serve the bare root path: there is no landing page, so render the 404 page."""
    context = {"request": request}
    return templates.TemplateResponse("404.html", context, status_code=404)
|
|
|
|
|
|
async def handle_request(request: Request):
    """Resolve ``/{path}`` to a short URL and send the visitor on.

    Responds with the 404 page when:
      * the path does not resolve to a short URL,
      * the short URL is bound to a domain that is not listed in ``DOMAINS``,
      * ``DOMAIN_STRICT_MODE`` is enabled and the request's Host is neither
        that domain nor one of its subdomains.

    Otherwise, clients whose User-Agent matches a redirect-checking tool get a
    direct redirect. Short URLs with ``loadingScreen`` set render
    ``redirect.html``, and all others get a plain redirect. Every path that
    reaches a short URL records a visit.
    """

    def not_found():
        return templates.TemplateResponse(
            "404.html", {"request": request}, status_code=404
        )

    path = request.path_params["path"]
    short_url = await _find_short_url_by_path(path)
    if short_url is None:
        return not_found()

    domains = Environment.get("DOMAINS", list[str], [])
    domain = short_url["domain"]

    # The Host header can be missing (e.g. HTTP/1.0 clients); use "" rather
    # than failing the whole request with a KeyError.
    raw_host = request.headers.get("host", "")
    logger.debug(
        f"Domain: {domain['name'] if domain is not None else None}, request.host: {raw_host}"
    )

    # Drop any ":port" suffix; host names compare case-insensitively.
    host = raw_host.partition(":")[0].lower()

    domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
    if domain is not None:
        domain_name = domain["name"]
        if domain_name not in domains:
            return not_found()
        if domain_strict_mode:
            expected = domain_name.lower()
            # Exact host or a real subdomain. A bare endswith() would also
            # accept look-alike hosts such as "evilexample.com" for "example.com".
            if host != expected and not host.endswith(f".{expected}"):
                return not_found()

    user_agent = request.headers.get("User-Agent", "").lower()

    # Redirect-checking tools (e.g. wheregoes) skip the loading screen.
    # NOTE(review): "someothertool" looks like a placeholder; confirm.
    if "wheregoes" in user_agent or "someothertool" in user_agent:
        return await _handle_short_url(request, short_url)

    if short_url["loadingScreen"]:
        await _track_visit(request, short_url)

        return templates.TemplateResponse(
            "redirect.html",
            {
                "request": request,
                "key": short_url["shortUrl"],
                "target_url": _get_redirect_url(short_url["targetUrl"]),
            },
        )

    return await _handle_short_url(request, short_url)
|
|
|
|
|
|
# Strong references to in-flight background cache refreshes. The event loop
# keeps only weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_background_refresh_tasks: set[asyncio.Task] = set()


async def _find_short_url_by_path(path: str) -> Optional[dict]:
    """Return the short-URL record for *path*, or None if there is none.

    Cached entries are served directly. When the cached entry has just
    expired, ``Cache.get`` still hands back the stale value (and evicts it).
    That value is returned immediately while a background task re-fetches the
    record (stale-while-revalidate).

    On a cache miss the GraphQL API is queried for non-deleted short URLs with
    this path, both in non-deleted groups and without any group. Every returned
    node is cached, and the first one is returned (grouped matches first).
    The blocking HTTP call runs in a worker thread so the event loop keeps
    serving other requests.

    Raises:
        Exception: if API_URL or API_KEY is not configured.
    """
    from_cache = Cache.get(path)
    if from_cache is not None:
        if Cache.check_expired(path):
            refresh = asyncio.create_task(_find_short_url_by_path(path))
            _background_refresh_tasks.add(refresh)
            refresh.add_done_callback(_background_refresh_tasks.discard)
        return from_cache

    api_url = Environment.get("API_URL", str)
    if api_url is None:
        raise Exception("API_URL is not set")

    api_key = Environment.get("API_KEY", str)
    if api_key is None:
        raise Exception("API_KEY is not set")

    # requests is synchronous; run it off the event loop.
    response = await asyncio.to_thread(
        requests.post,
        f"{api_url}/graphql",
        json={
            "query": """
                query getShortUrlByPath($path: String!) {
                    shortUrls(filter: [{ shortUrl: { equal: $path } }, { deleted: { equal: false } }, { group: { deleted: { equal: false } } }]) {
                        nodes {
                            id
                            shortUrl
                            targetUrl
                            description
                            group {
                                id
                                name
                            }
                            domain {
                                id
                                name
                            }
                            loadingScreen
                            deleted
                        }
                    }

                    shortUrlsWithoutGroup: shortUrls(filter: [{ shortUrl: { equal: $path } }, { deleted: { equal: false } }, { group: { isNull: true } }]) {
                        nodes {
                            id
                            shortUrl
                            targetUrl
                            description
                            group {
                                id
                                name
                            }
                            domain {
                                id
                                name
                            }
                            loadingScreen
                            deleted
                        }
                    }
                }
            """,
            "variables": {"path": path},
        },
        headers={"Authorization": f"API-Key {api_key}"},
    )
    data = response.json()
    if "errors" in data:
        logger.warning(f"Failed to find short url by path {path} -> {data['errors']}")

    # A GraphQL error response can carry "data": null or omit either alias;
    # treat any missing level as "not found" instead of raising.
    payload = data.get("data") or {}
    grouped = (payload.get("shortUrls") or {}).get("nodes")
    ungrouped = (payload.get("shortUrlsWithoutGroup") or {}).get("nodes")
    if grouped is None or ungrouped is None:
        return None

    nodes = [*grouped, *ungrouped]
    if not nodes:
        return None

    for node in nodes:
        Cache.set(node["shortUrl"], node)
    return nodes[0]
|
|
|
|
|
|
async def _handle_short_url(request: Request, short_url: dict):
    """Record a visit for *short_url*, then redirect the client to its target URL."""
    await _track_visit(request, short_url)
    destination = _get_redirect_url(short_url["targetUrl"])
    return RedirectResponse(destination)
|
|
|
|
|
|
async def _track_visit(r: Request, short_url: dict):
    """Report one visit of *short_url* to the API via the trackShortUrlVisit mutation.

    The request's User-Agent is sent along with the short URL's id. Failures of
    the API call itself (network errors, GraphQL errors) are logged, not raised.
    The blocking HTTP call runs in a worker thread so that tracking does not
    stall the event loop.

    Raises:
        Exception: if API_URL or API_KEY is not configured.
    """
    api_url = Environment.get("API_URL", str)
    if api_url is None:
        raise Exception("API_URL is not set")

    api_key = Environment.get("API_KEY", str)
    if api_key is None:
        raise Exception("API_KEY is not set")

    try:
        response = await asyncio.to_thread(
            requests.post,
            f"{api_url}/graphql",
            json={
                "query": """
                    mutation trackShortUrlVisit($id: Int!, $agent: String) {
                        shortUrl {
                            trackVisit(id: $id, agent: $agent)
                        }
                    }
                """,
                "variables": {
                    "id": short_url["id"],
                    "agent": r.headers.get("User-Agent"),
                },
            },
            headers={"Authorization": f"API-Key {api_key}"},
        )
        if response.status_code != 200:
            logger.warning(
                f"Failed to track visit for short url {short_url['shortUrl']}"
            )

        # Still parse the body on non-200 so any GraphQL "errors" get logged below.
        data = response.json()
        if "errors" in data:
            raise Exception(data["errors"])
        else:
            logger.debug(f"Tracked visit for short url {short_url['shortUrl']}")
    except Exception as e:
        # NOTE(review): passes the exception as a second positional argument;
        # confirm core.logger.Logger.error accepts that (stdlib logging would
        # treat it as a %-format argument).
        logger.error(
            f"Failed to update short url {short_url['shortUrl']} with error", e
        )
|
|
|
|
|
|
def _get_redirect_url(url: str) -> str:
|
|
protocols = Environment.get("PROTOCOLS", list[str], ["http", "https"])
|
|
|
|
if not any(url.startswith(f"{protocol}://") for protocol in protocols):
|
|
url = f"http://{url}"
|
|
|
|
return url
|
|
|
|
|
|
def _get_all_short_urls():
    """Fetch every non-deleted short URL from the API and store each one in ``Cache``.

    Queries both short URLs in non-deleted groups and short URLs without a
    group. This call is synchronous and blocks until the API responds.

    Raises:
        Exception: if API_URL or API_KEY is not configured.
        ValueError: if the response lacks the expected ``nodes`` lists,
            including when a GraphQL error response carries ``"data": null``.
    """
    logger.info("Loading all short urls to cache")

    api_url = Environment.get("API_URL", str)
    if api_url is None:
        raise Exception("API_URL is not set")

    api_key = Environment.get("API_KEY", str)
    if api_key is None:
        raise Exception("API_KEY is not set")

    response = requests.post(
        f"{api_url}/graphql",
        json={
            "query": """
                query getShortUrlsForCache {
                    shortUrls(filter: [{ deleted: { equal: false } }, { group: { deleted: { equal: false } } }]) {
                        nodes {
                            id
                            shortUrl
                            targetUrl
                            description
                            group {
                                id
                                name
                            }
                            domain {
                                id
                                name
                            }
                            loadingScreen
                            deleted
                        }
                    }

                    shortUrlsWithoutGroup: shortUrls(filter: [{ deleted: { equal: false } }, { group: { isNull: true } }]) {
                        nodes {
                            id
                            shortUrl
                            targetUrl
                            description
                            group {
                                id
                                name
                            }
                            domain {
                                id
                                name
                            }
                            loadingScreen
                            deleted
                        }
                    }
                }
            """,
            "variables": {},
        },
        headers={"Authorization": f"API-Key {api_key}"},
    )
    data = response.json()
    if "errors" in data:
        logger.warning(f"Failed to get all short urls -> {data['errors']}")

    # Any missing level ("data": null, absent alias, absent nodes) is reported
    # as the documented ValueError rather than a stray TypeError/KeyError.
    payload = data.get("data") or {}
    grouped = (payload.get("shortUrls") or {}).get("nodes")
    ungrouped = (payload.get("shortUrlsWithoutGroup") or {}).get("nodes")
    if grouped is None or ungrouped is None:
        raise ValueError("Failed to get all short urls")

    nodes = [*grouped, *ungrouped]

    for node in nodes:
        Cache.set(node["shortUrl"], node)

    logger.info(f"Loaded {len(nodes)} short urls to cache")
|
|
|
|
|
|
async def configure():
    """Startup hook: apply log level and environment settings, then preload the cache.

    The preload (``_get_all_short_urls``) is a synchronous call that fills
    ``Cache`` with every short URL before this hook returns.
    """
    level_name = Environment.get("LOG_LEVEL", str, "info")
    Logger.set_level(level_name)

    environment_name = Environment.get("ENVIRONMENT", str, "production")
    Environment.set_environment(environment_name)
    logger.info(f"Environment: {Environment.get_environment()}")

    _get_all_short_urls()
|
|
|
|
|
|
# Route table. Order matters: Starlette tries routes top to bottom, and the
# catch-all "/{path:path}" also matches "/" and "/static/...", so it must stay
# last.
routes = [
    Route("/", endpoint=index),
    Mount("/static", StaticFiles(directory="static"), name="static"),
    Route("/{path:path}", endpoint=handle_request),
]

# `configure` runs once at application startup (settings + cache preload).
# NOTE(review): newer Starlette releases deprecate on_startup in favour of
# `lifespan`; confirm against the pinned Starlette version.
app = Starlette(routes=routes, on_startup=[configure])
|
|
|
|
|
|
def main():
    """Run ``app`` under uvicorn on 0.0.0.0, port from ``PORT`` (default 5001)."""
    on_windows = sys.platform == "win32"
    if on_windows:
        # Use the selector-based event loop instead of Windows' default
        # Proactor loop.
        from asyncio import WindowsSelectorEventLoopPolicy

        asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())

    listen_port = Environment.get("PORT", int, 5001)
    # log_config=None: uvicorn installs no logging configuration of its own.
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_config=None)
|
|
|
|
|
|
# Start the server only when run as a script, not when the module is imported
# (e.g. by an external ASGI server loading `app`).
if __name__ == "__main__":
    main()
|