Compare commits

..

1 Commits

Author SHA1 Message Date
de3fd0c11f Redirector asgi rewrite
Some checks failed
Test before pr merge / test-translation-lint (pull_request) Successful in 1m18s
Test before pr merge / test-lint (pull_request) Successful in 1m20s
Test before pr merge / test-before-merge (pull_request) Failing after 2m24s
2025-03-10 23:52:13 +01:00
18 changed files with 208 additions and 76 deletions

View File

@ -3,6 +3,7 @@ from typing import Optional, Union
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.websockets import WebSocket
_request_context: ContextVar[Union[Request, None]] = ContextVar("request", default=None)
@ -19,5 +20,9 @@ class RequestMiddleware(BaseHTTPMiddleware):
return response
def set_request(request: Union[Request, WebSocket, None]):
_request_context.set(request)
def get_request() -> Optional[Request]:
return _request_context.get()

View File

@ -0,0 +1,26 @@
from ariadne.asgi.handlers import GraphQLTransportWSHandler
from starlette.datastructures import MutableHeaders
from starlette.websockets import WebSocket
from api.middleware.request import set_request
from core.logger import APILogger
logger = APILogger("WS")
class AuthenticatedGraphQLTransportWSHandler(GraphQLTransportWSHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, on_connect=self.on_connect, **kwargs)
@staticmethod
async def on_connect(ws: WebSocket, message: dict):
if "Authorization" not in message:
return True
mutable_headers = MutableHeaders()
mutable_headers["Authorization"] = message.get("Authorization", "")
ws._headers = mutable_headers
set_request(ws)
return True

View File

@ -3,6 +3,7 @@ from asyncio import iscoroutinefunction
from ariadne import SubscriptionType
from api.middleware.request import get_request
from api_graphql.abc.query_abc import QueryABC
from api_graphql.field.subscription_field_builder import SubscriptionFieldBuilder
from core.logger import APILogger

View File

@ -46,7 +46,8 @@ input ShortUrlFuzzy {
input ShortUrlFilter {
id: IntFilter
name: StringFilter
shortUrl: StringFilter
targetUrl: StringFilter
description: StringFilter
loadingScreen: BooleanFilter
@ -63,6 +64,7 @@ type ShortUrlMutation {
update(input: ShortUrlUpdateInput!): ShortUrl
delete(id: ID!): Boolean
restore(id: ID!): Boolean
trackVisit(id: ID!, agent: String): Boolean
}
input ShortUrlCreateInput {

View File

@ -8,6 +8,8 @@ from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
@ -39,6 +41,11 @@ class ShortUrlMutation(MutationABC):
self.resolve_restore,
require_any_permission=[Permissions.short_urls_delete],
)
self.mutation(
"trackVisit",
self.resolve_track_visit,
require_any_permission=[Permissions.short_urls_update],
)
@staticmethod
async def resolve_create(obj: ShortUrlCreateInput, *_):
@ -106,3 +113,9 @@ class ShortUrlMutation(MutationABC):
short_url = await shortUrlDao.get_by_id(id)
await shortUrlDao.restore(short_url)
return True
@staticmethod
async def resolve_track_visit(*_, id: int, agent: str):
logger.debug(f"track visit: {id} -- {agent}")
await shortUrlVisitDao.create(ShortUrlVisit(0, id, agent))
return True

View File

@ -1,6 +1,8 @@
import asyncio
import sys
from typing import Optional
import requests
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
@ -9,13 +11,8 @@ from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from core.database.database import Database
from core.environment import Environment
from core.logger import Logger
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
logger = Logger(__name__)
templates = Jinja2Templates(directory="templates")
@ -27,16 +24,16 @@ async def index(request: Request):
async def handle_request(request: Request):
path = request.path_params["path"]
short_url = await shortUrlDao.find_single_by({ShortUrl.short_url: path})
short_url = _find_short_url_by_path(path)
if short_url is None:
return templates.TemplateResponse(
"404.html", {"request": request}, status_code=404
)
domains = Environment.get("DOMAINS", list[str], [])
domain = await short_url.domain
domain = short_url["domain"]
logger.debug(
f"Domain: {domain.name if domain is not None else None}, request.host: {request.headers['host']}"
f"Domain: {domain["name"] if domain is not None else None}, request.host: {request.headers['host']}"
)
host = request.headers["host"]
@ -45,8 +42,8 @@ async def handle_request(request: Request):
domain_strict_mode = Environment.get("DOMAIN_STRICT_MODE", bool, False)
if domain is not None and (
domain.name not in domains
or (domain_strict_mode and not host.endswith(domain.name))
domain["name"] not in domains
or (domain_strict_mode and not host.endswith(domain["name"]))
):
return templates.TemplateResponse(
"404.html", {"request": request}, status_code=404
@ -55,39 +52,115 @@ async def handle_request(request: Request):
user_agent = request.headers.get("User-Agent", "").lower()
if "wheregoes" in user_agent or "someothertool" in user_agent:
return await _handle_short_url(request, path, short_url)
return await _handle_short_url(request, short_url)
if short_url.loading_screen:
if short_url["loadingScreen"]:
await _track_visit(request, short_url)
return templates.TemplateResponse(
"redirect.html",
{
"request": request,
"key": short_url.short_url,
"target_url": _get_redirect_url(short_url.target_url),
"key": short_url["shortUrl"],
"target_url": _get_redirect_url(short_url["targetUrl"]),
},
)
return await _handle_short_url(request, path, short_url)
return await _handle_short_url(request, short_url)
async def _handle_short_url(request: Request, path: str, short_url: ShortUrl):
if path.startswith("api/"):
path = path.replace("api/", "")
def _find_short_url_by_path(path: str) -> Optional[dict]:
api_url = Environment.get("API_URL", str)
if api_url is None:
raise Exception("API_URL is not set")
api_key = Environment.get("API_KEY", str)
if api_key is None:
raise Exception("API_KEY is not set")
request = requests.post(
f"{api_url}/graphql",
json={
"query": f"""
query getShortUrlByPath($path: String!) {{
shortUrls(filter: {{ shortUrl: {{ equal: $path }}, deleted: {{ equal: false }} }}) {{
nodes {{
id
shortUrl
targetUrl
description
group {{
id
name
}}
domain {{
id
name
}}
loadingScreen
deleted
}}
}}
}}
""",
"variables": {"path": path},
},
headers={"Authorization": f"API-Key {api_key}"},
)
data = request.json()["data"]["shortUrls"]["nodes"]
if len(data) == 0:
return None
return data[0]
async def _handle_short_url(request: Request, short_url: dict):
await _track_visit(request, short_url)
return RedirectResponse(_get_redirect_url(short_url.target_url))
return RedirectResponse(_get_redirect_url(short_url["targetUrl"]))
async def _track_visit(request: Request, short_url: ShortUrl):
async def _track_visit(r: Request, short_url: dict):
api_url = Environment.get("API_URL", str)
if api_url is None:
raise Exception("API_URL is not set")
api_key = Environment.get("API_KEY", str)
if api_key is None:
raise Exception("API_KEY is not set")
try:
await shortUrlVisitDao.create(
ShortUrlVisit(0, short_url.id, request.headers.get("User-Agent"))
request = requests.post(
f"{api_url}/graphql",
json={
"query": f"""
mutation trackShortUrlVisit($id: ID!, $agent: String) {{
shortUrl {{
trackVisit(id: $id, agent: $agent)
}}
}}
""",
"variables": {
"id": short_url["id"],
"agent": r.headers.get("User-Agent"),
},
},
headers={"Authorization": f"API-Key {api_key}"},
)
if request.status_code != 200:
logger.warning(
f"Failed to track visit for short url {short_url["shortUrl"]}"
)
data = request.json()
if "errors" in data:
raise Exception(data["errors"])
else:
logger.debug(f"Tracked visit for short url {short_url["shortUrl"]}")
except Exception as e:
logger.error(f"Failed to update short url {short_url.short_url} with error", e)
logger.error(
f"Failed to update short url {short_url["shortUrl"]} with error", e
)
def _get_redirect_url(url: str) -> str:
@ -102,8 +175,6 @@ async def configure():
Environment.set_environment(Environment.get("ENVIRONMENT", str, "production"))
logger.info(f"Environment: {Environment.get_environment()}")
await Database.startup_db()
routes = [
Route("/", endpoint=index),

View File

@ -12,6 +12,7 @@ from api.auth.keycloak_client import Keycloak
from api.broadcast import broadcast
from api.middleware.logging import LoggingMiddleware
from api.middleware.request import RequestMiddleware
from api.middleware.websocket import AuthenticatedGraphQLTransportWSHandler
from api.route import Route
from api_graphql.service.schema import schema
from core.database.database import Database
@ -118,7 +119,8 @@ class Startup:
WebSocketRoute(
"/graphql",
endpoint=GraphQL(
schema, websocket_handler=GraphQLTransportWSHandler()
schema,
websocket_handler=AuthenticatedGraphQLTransportWSHandler(),
),
),
],

View File

@ -8,17 +8,17 @@
</head>
<body>
<div class="w-full h-full flex flex-col justify-center items-center">
<div class="flex items-center justify-center">
<div class="relative w-screen h-screen bg-cover bg-center"
<div class="flex items-center justify-center w-full h-full">
<div class="relative w-full h-full bg-cover bg-center"
style="background-image: url('/static/custom/background.jpg')"></div>
<div class="absolute w-1/3 h-2/5 rounded-xl p-5 flex flex-col gap-5 justify-center items-center">
<div class="absolute w-11/12 sm:w-2/3 md:w-1/2 lg:w-1/3 h-2/5 rounded-xl p-5 flex flex-col gap-5 justify-center items-center">
<div class="absolute inset-0 bg-black opacity-70 rounded-xl"></div>
<div class="relative logo flex justify-center items-center">
<img class="h-48" src="/static/custom/logo.png" alt="logo">
<img class="h-24 sm:h-32 md:h-48" src="/static/custom/logo.png" alt="logo">
</div>
<h1 class="relative text-3xl text-white">Redirecting...</h1>
<h1 class="relative text-xl sm:text-2xl md:text-3xl text-white">Redirecting...</h1>
<p class="relative text-white">You will be redirected in <span id="countdown">5</span> seconds.</p>
</div>
</div>

7
web/.gitignore vendored
View File

@ -1 +1,6 @@
config.*.json
config.*.json
dist/
.angular/
node_modules/
coverage/

View File

@ -1,8 +1,15 @@
import { Component } from '@angular/core';
import { KeycloakService } from 'keycloak-angular';
@Component({
selector: 'app-home',
templateUrl: './home.component.html',
styleUrl: './home.component.scss',
})
export class HomeComponent {}
export class HomeComponent {
constructor(private keycloak: KeycloakService) {
if (!this.keycloak.isLoggedIn()) {
this.keycloak.login().then(() => {});
}
}
}

View File

@ -1,7 +1,7 @@
import { HttpInterceptorFn } from "@angular/common/http";
import { KeycloakService } from "keycloak-angular";
import { inject } from "@angular/core";
import { from, switchMap } from "rxjs";
import { HttpInterceptorFn } from '@angular/common/http';
import { KeycloakService } from 'keycloak-angular';
import { inject } from '@angular/core';
import { from, switchMap } from 'rxjs';
export const tokenInterceptor: HttpInterceptorFn = (req, next) => {
const keycloak = inject(KeycloakService);
@ -15,14 +15,14 @@ export const tokenInterceptor: HttpInterceptorFn = (req, next) => {
}
return from(keycloak.getToken()).pipe(
switchMap((token) => {
switchMap(token => {
const modifiedReq = token
? req.clone({
headers: req.headers.set("Authorization", `Bearer ${token}`),
headers: req.headers.set('Authorization', `Bearer ${token}`),
})
: req;
return next(modifiedReq);
}),
})
);
};

View File

@ -113,7 +113,7 @@ export class DomainsDataService
return this.apollo
.subscribe<{ domainChange: void }>({
query: gql`
subscription onRoleChange {
subscription onDomainChange {
domainChange
}
`,

View File

@ -30,8 +30,8 @@ export class DomainsPage extends PageBase<
});
}
load(): void {
this.loading = true;
load(silent?: boolean): void {
if (!silent) this.loading = true;
this.dataService
.load(this.filter, this.sort, this.skip, this.take)
.subscribe(result => {

View File

@ -121,7 +121,7 @@ export class GroupsDataService
return this.apollo
.subscribe<{ groupChange: void }>({
query: gql`
subscription onRoleChange {
subscription onGroupChange {
groupChange
}
`,

View File

@ -1,16 +1,16 @@
import { Component } from "@angular/core";
import { PageBase } from "src/app/core/base/page-base";
import { ToastService } from "src/app/service/toast.service";
import { ConfirmationDialogService } from "src/app/service/confirmation-dialog.service";
import { PermissionsEnum } from "src/app/model/auth/permissionsEnum";
import { Group } from "src/app/model/entities/group";
import { GroupsDataService } from "src/app/modules/admin/groups/groups.data.service";
import { GroupsColumns } from "src/app/modules/admin/groups/groups.columns";
import { Component } from '@angular/core';
import { PageBase } from 'src/app/core/base/page-base';
import { ToastService } from 'src/app/service/toast.service';
import { ConfirmationDialogService } from 'src/app/service/confirmation-dialog.service';
import { PermissionsEnum } from 'src/app/model/auth/permissionsEnum';
import { Group } from 'src/app/model/entities/group';
import { GroupsDataService } from 'src/app/modules/admin/groups/groups.data.service';
import { GroupsColumns } from 'src/app/modules/admin/groups/groups.columns';
@Component({
selector: "app-groups",
templateUrl: "./groups.page.html",
styleUrl: "./groups.page.scss",
selector: 'app-groups',
templateUrl: './groups.page.html',
styleUrl: './groups.page.scss',
})
export class GroupsPage extends PageBase<
Group,
@ -19,7 +19,7 @@ export class GroupsPage extends PageBase<
> {
constructor(
private toast: ToastService,
private confirmation: ConfirmationDialogService,
private confirmation: ConfirmationDialogService
) {
super(true, {
read: [PermissionsEnum.groups],
@ -30,11 +30,11 @@ export class GroupsPage extends PageBase<
});
}
load(): void {
this.loading = true;
load(silent?: boolean): void {
if (!silent) this.loading = true;
this.dataService
.load(this.filter, this.sort, this.skip, this.take)
.subscribe((result) => {
.subscribe(result => {
this.result = result;
this.loading = false;
});
@ -42,12 +42,12 @@ export class GroupsPage extends PageBase<
delete(group: Group): void {
this.confirmation.confirmDialog({
header: "dialog.delete.header",
message: "dialog.delete.message",
header: 'dialog.delete.header',
message: 'dialog.delete.message',
accept: () => {
this.loading = true;
this.dataService.delete(group).subscribe(() => {
this.toast.success("action.deleted");
this.toast.success('action.deleted');
this.load();
});
},
@ -57,12 +57,12 @@ export class GroupsPage extends PageBase<
restore(group: Group): void {
this.confirmation.confirmDialog({
header: "dialog.restore.header",
message: "dialog.restore.message",
header: 'dialog.restore.header',
message: 'dialog.restore.message',
accept: () => {
this.loading = true;
this.dataService.restore(group).subscribe(() => {
this.toast.success("action.restored");
this.toast.success('action.restored');
this.load();
});
},

View File

@ -139,7 +139,7 @@ export class ShortUrlsDataService
return this.apollo
.subscribe<{ shortUrlChange: void }>({
query: gql`
subscription onRoleChange {
subscription onShortUrlChange {
shortUrlChange
}
`,

View File

@ -95,8 +95,8 @@ export class ShortUrlsPage
};
}
load(): void {
this.loading = true;
load(silent?: boolean): void {
if (!silent) this.loading = true;
this.dataService
.load(this.filter, this.sort, this.skip, this.take)
.subscribe(result => {

View File

@ -70,6 +70,7 @@ import { ConfigService } from 'src/app/service/config.service';
import { Logger } from 'src/app/service/logger.service';
import { Router } from '@angular/router';
import { SliderModule } from 'primeng/slider';
import { KeycloakService } from 'keycloak-angular';
const sharedModules = [
StepsModule,
@ -148,11 +149,11 @@ function debounce(func: (...args: unknown[]) => void, wait: number) {
exports: [...sharedModules, ...sharedComponents],
providers: [
provideHttpClient(withInterceptors([tokenInterceptor])),
provideApollo(() => {
const logger = new Logger('graphql');
const settings = inject(ConfigService);
const keycloak = inject(KeycloakService);
const httpLink = inject(HttpLink);
const router = inject(Router);
@ -168,9 +169,11 @@ function debounce(func: (...args: unknown[]) => void, wait: number) {
retryAttempts: Infinity,
shouldRetry: () => true,
keepAlive: 10000,
connectionParams: () => ({
authToken: localStorage.getItem('token'),
}),
connectionParams: async () => {
return {
Authorization: `Bearer ${await keycloak.getToken()}`,
};
},
on: {
connected: () => {
logger.info('WebSocket connected');
@ -187,10 +190,7 @@ function debounce(func: (...args: unknown[]) => void, wait: number) {
})
);
// Using the ability to split links, you can send data to each link
// depending on what kind of operation is being sent
const link = split(
// Split based on operation type
({ query }) => {
const definition = getMainDefinition(query);
return (