Redirector asgi rewrite
All checks were successful
Test before pr merge / test-translation-lint (pull_request) Successful in 34s
Test before pr merge / test-lint (pull_request) Successful in 38s
Test before pr merge / test-before-merge (pull_request) Successful in 1m35s

This commit is contained in:
Sven Heidemann 2025-03-08 10:08:03 +01:00
parent 993654dabd
commit 520e898dff
19 changed files with 274 additions and 131 deletions

View File

@ -3,6 +3,7 @@ from typing import Optional, Union
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.websockets import WebSocket
_request_context: ContextVar[Union[Request, None]] = ContextVar("request", default=None)
@ -19,5 +20,9 @@ class RequestMiddleware(BaseHTTPMiddleware):
return response
def set_request(request: Union[Request, WebSocket, None]):
_request_context.set(request)
def get_request() -> Optional[Request]:
return _request_context.get()

View File

@ -0,0 +1,26 @@
from ariadne.asgi.handlers import GraphQLTransportWSHandler
from starlette.datastructures import MutableHeaders
from starlette.websockets import WebSocket
from api.middleware.request import set_request
from core.logger import APILogger
logger = APILogger("WS")
class AuthenticatedGraphQLTransportWSHandler(GraphQLTransportWSHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, on_connect=self.on_connect, **kwargs)
@staticmethod
async def on_connect(ws: WebSocket, message: dict):
if "Authorization" not in message:
return True
mutable_headers = MutableHeaders()
mutable_headers["Authorization"] = message.get("Authorization", "")
ws._headers = mutable_headers
set_request(ws)
return True

View File

@ -3,6 +3,7 @@ from asyncio import iscoroutinefunction
from ariadne import SubscriptionType
from api.middleware.request import get_request
from api_graphql.abc.query_abc import QueryABC
from api_graphql.field.subscription_field_builder import SubscriptionFieldBuilder
from core.logger import APILogger

View File

@ -46,7 +46,8 @@ input ShortUrlFuzzy {
input ShortUrlFilter {
id: IntFilter
name: StringFilter
shortUrl: StringFilter
targetUrl: StringFilter
description: StringFilter
loadingScreen: BooleanFilter
@ -63,6 +64,7 @@ type ShortUrlMutation {
update(input: ShortUrlUpdateInput!): ShortUrl
delete(id: ID!): Boolean
restore(id: ID!): Boolean
trackVisit(id: ID!, agent: String): Boolean
}
input ShortUrlCreateInput {

View File

@ -8,6 +8,8 @@ from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
@ -39,6 +41,11 @@ class ShortUrlMutation(MutationABC):
self.resolve_restore,
require_any_permission=[Permissions.short_urls_delete],
)
self.mutation(
"trackVisit",
self.resolve_track_visit,
require_any_permission=[Permissions.short_urls_update],
)
@staticmethod
async def resolve_create(obj: ShortUrlCreateInput, *_):
@ -106,3 +113,9 @@ class ShortUrlMutation(MutationABC):
short_url = await shortUrlDao.get_by_id(id)
await shortUrlDao.restore(short_url)
return True
@staticmethod
async def resolve_track_visit(*_, id: int, agent: str):
logger.debug(f"track visit: {id} -- {agent}")
await shortUrlVisitDao.create(ShortUrlVisit(0, id, agent))
return True

View File

@ -1,117 +1,188 @@
import asyncio
import sys
from typing import Optional
import eventlet
from eventlet import wsgi
from flask import Flask, request, Response, redirect, render_template
import requests
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import RedirectResponse
from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from core.database.database import Database
from core.environment import Environment
from core.logger import Logger
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
logger = Logger(__name__)
templates = Jinja2Templates(directory="templates")
class Redirector(Flask):
def __init__(self, *args, **kwargs):
Flask.__init__(self, *args, **kwargs)
async def index(request: Request):
return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
app = Redirector(__name__)
@app.route("/")
def index():
return render_template("404.html"), 404
@app.route("/<path:path>")
async def _handle_request(path: str):
short_url = await _find_short_url_by_url(path)
async def handle_request(request: Request):
path = request.path_params["path"]
short_url = _find_short_url_by_path(path)
if short_url is None:
return render_template("404.html"), 404
return templates.TemplateResponse(
"404.html", {"request": request}, status_code=404
)
domains = Environment.get("DOMAINS", list[str], [])
domain = await short_url.domain
domain = short_url["domain"]
logger.debug(
f"Domain: {domain.name if domain is not None else None}, request.host: {request.host}"
f"Domain: {domain["name"] if domain is not None else None}, request.host: {request.headers['host']}"
)
host = request.host
host = request.headers["host"]
if ":" in host:
host = host.split(":")[0]
domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
if domain is not None and (
domain.name not in domains
or (domain_strict_mode and not host.endswith(domain.name))
domain["name"] not in domains
or (domain_strict_mode and not host.endswith(domain["name"]))
):
return render_template("404.html"), 404
return templates.TemplateResponse(
"404.html", {"request": request}, status_code=404
)
user_agent = request.headers.get("User-Agent", "").lower()
if "wheregoes" in user_agent or "someothertool" in user_agent:
return await _handle_short_url(path, short_url)
return await _handle_short_url(request, short_url)
if short_url.loading_screen:
await _track_visit(short_url)
if short_url["loadingScreen"]:
await _track_visit(request, short_url)
return render_template(
return templates.TemplateResponse(
"redirect.html",
key=short_url.short_url,
target_url=_get_redirect_url(short_url.target_url),
{
"request": request,
"key": short_url["shortUrl"],
"target_url": _get_redirect_url(short_url["targetUrl"]),
},
)
return await _handle_short_url(path, short_url)
return await _handle_short_url(request, short_url)
async def _handle_short_url(path: str, short_url: ShortUrl):
if path.startswith("api/"):
path = path.replace("api/", "")
def _find_short_url_by_path(path: str) -> Optional[dict]:
api_url = Environment.get("API_URL", str)
if api_url is None:
raise Exception("API_URL is not set")
await _track_visit(short_url)
api_key = Environment.get("API_KEY", str)
if api_key is None:
raise Exception("API_KEY is not set")
return _do_redirect(short_url.target_url)
request = requests.post(
f"{api_url}/graphql",
json={
"query": f"""
query getShortUrlByPath($path: String!) {{
shortUrls(filter: {{ shortUrl: {{ equal: $path }}, deleted: {{ equal: false }} }}) {{
nodes {{
id
shortUrl
targetUrl
description
group {{
id
name
}}
domain {{
id
name
}}
loadingScreen
deleted
}}
}}
}}
""",
"variables": {"path": path},
},
headers={"Authorization": f"API-Key {api_key}"},
)
data = request.json()["data"]["shortUrls"]["nodes"]
if len(data) == 0:
return None
return data[0]
async def _track_visit(short_url: ShortUrl):
async def _handle_short_url(request: Request, short_url: dict):
await _track_visit(request, short_url)
return RedirectResponse(_get_redirect_url(short_url["targetUrl"]))
async def _track_visit(r: Request, short_url: dict):
api_url = Environment.get("API_URL", str)
if api_url is None:
raise Exception("API_URL is not set")
api_key = Environment.get("API_KEY", str)
if api_key is None:
raise Exception("API_KEY is not set")
try:
await shortUrlVisitDao.create(
ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent"))
request = requests.post(
f"{api_url}/graphql",
json={
"query": f"""
mutation trackShortUrlVisit($id: ID!, $agent: String) {{
shortUrl {{
trackVisit(id: $id, agent: $agent)
}}
}}
""",
"variables": {
"id": short_url["id"],
"agent": r.headers.get("User-Agent"),
},
},
headers={"Authorization": f"API-Key {api_key}"},
)
if request.status_code != 200:
logger.warning(
f"Failed to track visit for short url {short_url["shortUrl"]}"
)
data = request.json()
if "errors" in data:
raise Exception(data["errors"])
else:
logger.debug(f"Tracked visit for short url {short_url["shortUrl"]}")
except Exception as e:
logger.error(f"Failed to update short url {short_url.short_url} with error", e)
async def _find_short_url_by_url(url: str) -> ShortUrl:
return await shortUrlDao.find_single_by({ShortUrl.short_url: url})
logger.error(
f"Failed to update short url {short_url["shortUrl"]} with error", e
)
def _get_redirect_url(url: str) -> str:
# todo: multiple protocols like ts3://
if not url.startswith("http://") and not url.startswith("https://"):
url = f"http://{url}"
return url
def _do_redirect(url: str) -> Response:
return redirect(_get_redirect_url(url))
async def configure():
Logger.set_level(Environment.get("LOG_LEVEL", str, "info"))
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
logger.info(f"Environment: {Environment.get_environment()}")
app.debug = Environment.get_environment() == "development"
await Database.startup_db()
routes = [
Route("/", endpoint=index),
Mount("/static", StaticFiles(directory="static"), name="static"),
Route("/{path:path}", endpoint=handle_request),
]
app = Starlette(routes=routes, on_startup=[configure])
def main():
@ -120,26 +191,13 @@ def main():
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
loop = asyncio.new_event_loop()
loop.run_until_complete(configure())
loop.close()
port = Environment.get("PORT", int, 5001)
logger.info(f"Start API on port: {port}")
if Environment.get_environment() == "development":
logger.info(f"Playground: http://localhost:{port}/")
wsgi.server(eventlet.listen(("0.0.0.0", port)), app, log_output=False)
uvicorn.run(
app,
host="0.0.0.0",
port=Environment.get("PORT", int, 5001),
log_config=None,
)
if __name__ == "__main__":
main()
# ((
# ( )
# ; / ,
# / \/
# / |
# / ~/
# / ) ) ~ edraft
# ___// | /
# --' \_~-,

View File

@ -12,6 +12,7 @@ from api.auth.keycloak_client import Keycloak
from api.broadcast import broadcast
from api.middleware.logging import LoggingMiddleware
from api.middleware.request import RequestMiddleware
from api.middleware.websocket import AuthenticatedGraphQLTransportWSHandler
from api.route import Route
from api_graphql.service.schema import schema
from core.database.database import Database
@ -118,7 +119,8 @@ class Startup:
WebSocketRoute(
"/graphql",
endpoint=GraphQL(
schema, websocket_handler=GraphQLTransportWSHandler()
schema,
websocket_handler=AuthenticatedGraphQLTransportWSHandler(),
),
),
],

View File

@ -8,17 +8,17 @@
</head>
<body>
<div class="w-full h-full flex flex-col justify-center items-center">
<div class="flex items-center justify-center">
<div class="relative w-screen h-screen bg-cover bg-center"
<div class="flex items-center justify-center w-full h-full">
<div class="relative w-full h-full bg-cover bg-center"
style="background-image: url('/static/custom/background.jpg')"></div>
<div class="absolute w-1/3 h-2/5 rounded-xl p-5 flex flex-col gap-5 justify-center items-center">
<div class="absolute w-11/12 sm:w-2/3 md:w-1/2 lg:w-1/3 h-2/5 rounded-xl p-5 flex flex-col gap-5 justify-center items-center">
<div class="absolute inset-0 bg-black opacity-70 rounded-xl"></div>
<div class="relative logo flex justify-center items-center">
<img class="h-48" src="/static/custom/logo.png" alt="logo">
<img class="h-24 sm:h-32 md:h-48" src="/static/custom/logo.png" alt="logo">
</div>
<h1 class="relative text-3xl text-white">Redirecting...</h1>
<h1 class="relative text-xl sm:text-2xl md:text-3xl text-white">Redirecting...</h1>
<p class="relative text-white">You will be redirected in <span id="countdown">5</span> seconds.</p>
</div>
</div>

7
web/.gitignore vendored
View File

@ -1 +1,6 @@
config.*.json
config.*.json
dist/
.angular/
node_modules/
coverage/

View File

@ -1,14 +1,38 @@
import { ComponentFixture, TestBed } from "@angular/core/testing";
import { ComponentFixture, TestBed } from '@angular/core/testing';
import { HomeComponent } from "./home.component";
import { HomeComponent } from './home.component';
import { SharedModule } from 'src/app/modules/shared/shared.module';
import { TranslateModule } from '@ngx-translate/core';
import { AuthService } from 'src/app/service/auth.service';
import { KeycloakService } from 'keycloak-angular';
import { ErrorHandlingService } from 'src/app/service/error-handling.service';
import { ToastService } from 'src/app/service/toast.service';
import { ConfirmationService, MessageService } from 'primeng/api';
import { ActivatedRoute } from '@angular/router';
import { of } from 'rxjs';
describe("HomeComponent", () => {
describe('HomeComponent', () => {
let component: HomeComponent;
let fixture: ComponentFixture<HomeComponent>;
beforeEach(async () => {
await TestBed.configureTestingModule({
declarations: [HomeComponent],
imports: [SharedModule, TranslateModule.forRoot()],
providers: [
AuthService,
KeycloakService,
ErrorHandlingService,
ToastService,
MessageService,
ConfirmationService,
{
provide: ActivatedRoute,
useValue: {
snapshot: { params: of({}) },
},
},
],
}).compileComponents();
fixture = TestBed.createComponent(HomeComponent);
@ -16,7 +40,7 @@ describe("HomeComponent", () => {
fixture.detectChanges();
});
it("should create", () => {
it('should create', () => {
expect(component).toBeTruthy();
});
});

View File

@ -1,8 +1,15 @@
import { Component } from '@angular/core';
import { KeycloakService } from 'keycloak-angular';
@Component({
selector: 'app-home',
templateUrl: './home.component.html',
styleUrl: './home.component.scss',
})
export class HomeComponent {}
export class HomeComponent {
constructor(private keycloak: KeycloakService) {
if (!this.keycloak.isLoggedIn()) {
this.keycloak.login().then(() => {});
}
}
}

View File

@ -1,7 +1,7 @@
import { HttpInterceptorFn } from "@angular/common/http";
import { KeycloakService } from "keycloak-angular";
import { inject } from "@angular/core";
import { from, switchMap } from "rxjs";
import { HttpInterceptorFn } from '@angular/common/http';
import { KeycloakService } from 'keycloak-angular';
import { inject } from '@angular/core';
import { from, switchMap } from 'rxjs';
export const tokenInterceptor: HttpInterceptorFn = (req, next) => {
const keycloak = inject(KeycloakService);
@ -15,14 +15,14 @@ export const tokenInterceptor: HttpInterceptorFn = (req, next) => {
}
return from(keycloak.getToken()).pipe(
switchMap((token) => {
switchMap(token => {
const modifiedReq = token
? req.clone({
headers: req.headers.set("Authorization", `Bearer ${token}`),
headers: req.headers.set('Authorization', `Bearer ${token}`),
})
: req;
return next(modifiedReq);
}),
})
);
};

View File

@ -113,7 +113,7 @@ export class DomainsDataService
return this.apollo
.subscribe<{ domainChange: void }>({
query: gql`
subscription onRoleChange {
subscription onDomainChange {
domainChange
}
`,

View File

@ -30,8 +30,8 @@ export class DomainsPage extends PageBase<
});
}
load(): void {
this.loading = true;
load(silent?: boolean): void {
if (!silent) this.loading = true;
this.dataService
.load(this.filter, this.sort, this.skip, this.take)
.subscribe(result => {

View File

@ -121,7 +121,7 @@ export class GroupsDataService
return this.apollo
.subscribe<{ groupChange: void }>({
query: gql`
subscription onRoleChange {
subscription onGroupChange {
groupChange
}
`,

View File

@ -1,16 +1,16 @@
import { Component } from "@angular/core";
import { PageBase } from "src/app/core/base/page-base";
import { ToastService } from "src/app/service/toast.service";
import { ConfirmationDialogService } from "src/app/service/confirmation-dialog.service";
import { PermissionsEnum } from "src/app/model/auth/permissionsEnum";
import { Group } from "src/app/model/entities/group";
import { GroupsDataService } from "src/app/modules/admin/groups/groups.data.service";
import { GroupsColumns } from "src/app/modules/admin/groups/groups.columns";
import { Component } from '@angular/core';
import { PageBase } from 'src/app/core/base/page-base';
import { ToastService } from 'src/app/service/toast.service';
import { ConfirmationDialogService } from 'src/app/service/confirmation-dialog.service';
import { PermissionsEnum } from 'src/app/model/auth/permissionsEnum';
import { Group } from 'src/app/model/entities/group';
import { GroupsDataService } from 'src/app/modules/admin/groups/groups.data.service';
import { GroupsColumns } from 'src/app/modules/admin/groups/groups.columns';
@Component({
selector: "app-groups",
templateUrl: "./groups.page.html",
styleUrl: "./groups.page.scss",
selector: 'app-groups',
templateUrl: './groups.page.html',
styleUrl: './groups.page.scss',
})
export class GroupsPage extends PageBase<
Group,
@ -19,7 +19,7 @@ export class GroupsPage extends PageBase<
> {
constructor(
private toast: ToastService,
private confirmation: ConfirmationDialogService,
private confirmation: ConfirmationDialogService
) {
super(true, {
read: [PermissionsEnum.groups],
@ -30,11 +30,11 @@ export class GroupsPage extends PageBase<
});
}
load(): void {
this.loading = true;
load(silent?: boolean): void {
if (!silent) this.loading = true;
this.dataService
.load(this.filter, this.sort, this.skip, this.take)
.subscribe((result) => {
.subscribe(result => {
this.result = result;
this.loading = false;
});
@ -42,12 +42,12 @@ export class GroupsPage extends PageBase<
delete(group: Group): void {
this.confirmation.confirmDialog({
header: "dialog.delete.header",
message: "dialog.delete.message",
header: 'dialog.delete.header',
message: 'dialog.delete.message',
accept: () => {
this.loading = true;
this.dataService.delete(group).subscribe(() => {
this.toast.success("action.deleted");
this.toast.success('action.deleted');
this.load();
});
},
@ -57,12 +57,12 @@ export class GroupsPage extends PageBase<
restore(group: Group): void {
this.confirmation.confirmDialog({
header: "dialog.restore.header",
message: "dialog.restore.message",
header: 'dialog.restore.header',
message: 'dialog.restore.message',
accept: () => {
this.loading = true;
this.dataService.restore(group).subscribe(() => {
this.toast.success("action.restored");
this.toast.success('action.restored');
this.load();
});
},

View File

@ -139,7 +139,7 @@ export class ShortUrlsDataService
return this.apollo
.subscribe<{ shortUrlChange: void }>({
query: gql`
subscription onRoleChange {
subscription onShortUrlChange {
shortUrlChange
}
`,

View File

@ -95,8 +95,8 @@ export class ShortUrlsPage
};
}
load(): void {
this.loading = true;
load(silent?: boolean): void {
if (!silent) this.loading = true;
this.dataService
.load(this.filter, this.sort, this.skip, this.take)
.subscribe(result => {

View File

@ -70,6 +70,7 @@ import { ConfigService } from 'src/app/service/config.service';
import { Logger } from 'src/app/service/logger.service';
import { Router } from '@angular/router';
import { SliderModule } from 'primeng/slider';
import { KeycloakService } from 'keycloak-angular';
const sharedModules = [
StepsModule,
@ -148,11 +149,11 @@ function debounce(func: (...args: unknown[]) => void, wait: number) {
exports: [...sharedModules, ...sharedComponents],
providers: [
provideHttpClient(withInterceptors([tokenInterceptor])),
provideApollo(() => {
const logger = new Logger('graphql');
const settings = inject(ConfigService);
const keycloak = inject(KeycloakService);
const httpLink = inject(HttpLink);
const router = inject(Router);
@ -168,9 +169,11 @@ function debounce(func: (...args: unknown[]) => void, wait: number) {
retryAttempts: Infinity,
shouldRetry: () => true,
keepAlive: 10000,
connectionParams: () => ({
authToken: localStorage.getItem('token'),
}),
connectionParams: async () => {
return {
Authorization: `Bearer ${await keycloak.getToken()}`,
};
},
on: {
connected: () => {
logger.info('WebSocket connected');
@ -187,10 +190,7 @@ function debounce(func: (...args: unknown[]) => void, wait: number) {
})
);
// Using the ability to split links, you can send data to each link
// depending on what kind of operation is being sent
const link = split(
// Split based on operation type
({ query }) => {
const definition = getMainDefinition(query);
return (