Merge pull request 'dev' (#17) from dev into master
All checks were successful
Build on push / prepare (push) Successful in 5s
Build on push / build-redirector (push) Successful in 29s
Build on push / build-api (push) Successful in 29s
Build on push / build-web (push) Successful in 53s

Reviewed-on: #17
This commit is contained in:
Sven Heidemann 2025-03-12 10:09:01 +01:00
commit 433188995e
224 changed files with 7103 additions and 1862 deletions

View File

@ -1,13 +1,48 @@
name: Build dev on push
run-name: Build dev on push
name: Build on push
run-name: Build on push
on:
push:
branches:
- dev
jobs:
prepare:
runs-on: [runner]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Get Date and Build Number
run: |
git fetch --tags
git tag
DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
BUILD_NUMBER=$(($TAG_COUNT + 1))
BUILD_VERSION="${DATE}.${BUILD_NUMBER}-dev"
echo "$BUILD_VERSION" > version.txt
echo "VERSION $BUILD_VERSION"
- name: Create Git Tag for Build
run: |
git config user.name "ci"
git config user.email "dev@sh-edraft.de"
echo "tag $(cat version.txt)"
git tag $(cat version.txt)
git push origin --tags
- name: Upload build version artifact
uses: actions/upload-artifact@v3
with:
name: version
path: version.txt
build-api:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
@ -15,10 +50,16 @@ jobs:
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker
run: |
cd api
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-api-dev:$(cat ../version.txt) .
echo "VERSION = \"$(cat version.txt)\"" > version.py
docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1
@ -29,10 +70,11 @@ jobs:
- name: Push image
run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api-dev:$(cat version.txt)
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat version.txt)
build-redirector:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
@ -40,10 +82,15 @@ jobs:
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker
run: |
cd api
docker build -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector-dev:$(cat ../version.txt) .
docker build --no-cache -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1
@ -54,10 +101,11 @@ jobs:
- name: Push image
run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-redirector-dev:$(cat version.txt)
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat version.txt)
build-web:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
@ -65,6 +113,11 @@ jobs:
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Prepare web build
run: |
cd web
@ -78,7 +131,7 @@ jobs:
- name: Build docker
run: |
cd web
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-web-dev:$(cat ../version.txt) .
docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1
@ -89,4 +142,4 @@ jobs:
- name: Push image
run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-web-dev:$(cat version.txt)
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat version.txt)

View File

@ -6,7 +6,7 @@ on:
- master
jobs:
build-api:
prepare:
runs-on: [runner]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
@ -15,10 +15,51 @@ jobs:
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Get Date and Build Number
run: |
git fetch
git tag
DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
BUILD_NUMBER=$(($TAG_COUNT + 1))
BUILD_VERSION="${DATE}.${BUILD_NUMBER}"
echo "$BUILD_VERSION" > version.txt
echo "VERSION $BUILD_VERSION"
- name: Create Git Tag for Build
run: |
git config user.name "ci"
git config user.email "dev@sh-edraft.de"
echo "tag $(cat version.txt)"
git tag $(cat version.txt)
git push origin --tags
- name: Upload build version artifact
uses: actions/upload-artifact@v3
with:
name: version
path: version.txt
build-api:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker
run: |
cd api
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat ../version.txt) .
echo "VERSION = \"$(cat version.txt)\"" > version.py
docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1
@ -30,9 +71,10 @@ jobs:
- name: Push image
run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat version.txt)
build-redirector:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
@ -40,10 +82,15 @@ jobs:
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker
run: |
cd api
docker build -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat ../version.txt) .
docker build --no-cache -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1
@ -58,6 +105,7 @@ jobs:
build-web:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
@ -65,6 +113,11 @@ jobs:
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Prepare web build
run: |
cd web
@ -78,7 +131,7 @@ jobs:
- name: Build docker
run: |
cd web
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat ../version.txt) .
docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1

View File

@ -0,0 +1,29 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Installing dependencies
working-directory: ./api
run: |
python3.12 -m pip install -r requirements-dev.txt
- name: Checking black
working-directory: ./api
run: python3.12 -m black src --check

View File

@ -1,39 +0,0 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-before-merge:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
run: npm ci
- name: Checking eslint
run: npm run lint
- name: Setup chrome
run: |
wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -
echo "deb http://dl.google.com/linux/chrome/deb/ stable main" > /etc/apt/sources.list.d/google.list
apt-get update
apt-get install -y google-chrome-stable xvfb
- name: Testing
run: npm run test:ci

View File

@ -0,0 +1,79 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
working-directory: ./web
run: npm ci
- name: Checking eslint
working-directory: ./web
run: npm run lint
test-translation-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
working-directory: ./web
run: npm ci
- name: Checking translations
working-directory: ./web
run: npm run lint:translations
test-before-merge:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
working-directory: ./web
run: npm ci
- name: Setup chrome
working-directory: ./web
run: |
wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -
echo "deb http://dl.google.com/linux/chrome/deb/ stable main" > /etc/apt/sources.list.d/google.list
apt-get update
apt-get install -y google-chrome-stable xvfb
- name: Testing
working-directory: ./web
run: npm run test:ci

View File

@ -1,10 +1,13 @@
ariadne==0.23.0
eventlet==0.37.0
broadcaster==0.3.1
graphql-core==3.2.5
Flask[async]==3.1.0
Flask-Cors==5.0.0
async-property==0.2.2
python-keycloak==4.7.3
psycopg[binary]==3.2.3
psycopg-pool==3.2.4
Werkzeug==3.1.3
uvicorn==0.34.0
starlette==0.46.0
requests==2.32.3
Jinja2==3.1.5
python-keycloak==5.3.1
python-multipart==0.0.20
websockets==15.0

View File

@ -1,78 +1,45 @@
import importlib
import os
import time
from uuid import uuid4
from typing import Optional
from flask import Flask, request, g
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route
from core.environment import Environment
from core.logger import APILogger
app = Flask(__name__)
logger = APILogger(__name__)
def filter_relevant_headers(headers: dict) -> dict:
relevant_keys = {
"Content-Type",
"Host",
"Connection",
"User-Agent",
"Origin",
"Referer",
"Accept",
}
return {key: value for key, value in headers.items() if key in relevant_keys}
class API:
app: Optional[Starlette] = None
@classmethod
def create(cls, app: Starlette):
cls.app = app
@app.before_request
async def log_request():
g.request_id = uuid4()
g.start_time = time.time()
logger.debug(
f"Request {g.request_id}: {request.method}@{request.path} from {request.remote_addr}"
)
user = await Route.get_user()
@staticmethod
async def handle_exception(request: Request, exc: Exception):
logger.error(f"Request {request.state.request_id}", exc)
return JSONResponse({"error": str(exc)}, status_code=500)
request_info = {
"headers": filter_relevant_headers(dict(request.headers)),
"args": request.args.to_dict(),
"form-data": request.form.to_dict(),
"payload": request.get_json(silent=True),
"user": f"{user.id}-{user.keycloak_id}" if user else None,
"files": (
{key: file.filename for key, file in request.files.items()}
if request.files
else None
),
}
@staticmethod
def get_allowed_origins():
client_urls = Environment.get("CLIENT_URLS", str)
if client_urls is None or client_urls == "":
allowed_origins = ["*"]
logger.warning("No allowed origins specified, allowing all origins")
else:
allowed_origins = client_urls.split(",")
logger.trace(f"Request {g.request_id}: {request_info}")
return allowed_origins
@app.after_request
def log_after_request(response):
# calc the time it took to process the request
duration = (time.time() - g.start_time) * 1000
logger.info(
f"Request finished {g.request_id}: {response.status_code}-{request.method}@{request.path} from {request.remote_addr} in {duration:.2f}ms"
)
return response
@app.errorhandler(Exception)
def handle_exception(e):
logger.error(f"Request {g.request_id}", e)
return {"error": str(e)}, 500
# used to import all routes
routes_dir = os.path.join(os.path.dirname(__file__), "routes")
for filename in os.listdir(routes_dir):
if filename.endswith(".py") and filename != "__init__.py":
module_name = f"api.routes.{filename[:-3]}"
importlib.import_module(module_name)
# Explicitly register the routes
for route, (view_func, options) in Route.registered_routes.items():
app.add_url_rule(route, view_func=view_func, **options)
@staticmethod
def import_routes():
# used to import all routes
routes_dir = os.path.join(os.path.dirname(__file__), "routes")
for filename in os.listdir(routes_dir):
if filename.endswith(".py") and filename != "__init__.py":
module_name = f"api.routes.{filename[:-3]}"
importlib.import_module(module_name)

5
api/src/api/broadcast.py Normal file
View File

@ -0,0 +1,5 @@
from typing import Optional
from broadcaster import Broadcast
broadcast: Optional[Broadcast] = Broadcast("memory://")

View File

@ -1,9 +1,9 @@
from flask import jsonify
from starlette.responses import JSONResponse
def unauthorized():
return jsonify({"error": "Unauthorized"}), 401
return JSONResponse({"error": "Unauthorized"}, 401)
def forbidden():
return jsonify({"error": "Unauthorized"}), 401
return JSONResponse({"error": "Unauthorized"}, 401)

View File

View File

@ -0,0 +1,73 @@
import time
from uuid import uuid4
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from api.route import Route
from core.logger import APILogger
logger = APILogger("api.api")
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
await self._log_request(request)
response = await call_next(request)
await self._log_after_request(request, response)
return response
@staticmethod
def _filter_relevant_headers(headers: dict) -> dict:
relevant_keys = {
"content-type",
"host",
"connection",
"user-agent",
"origin",
"referer",
"accept",
}
return {key: value for key, value in headers.items() if key in relevant_keys}
@classmethod
async def _log_request(cls, request: Request):
request.state.request_id = uuid4()
request.state.start_time = time.time()
logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
)
user = await Route.get_user()
request_info = {
"headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params),
"form-data": (
await request.form()
if request.headers.get("content-type")
== "application/x-www-form-urlencoded"
else None
),
"payload": (
await request.json()
if request.headers.get("content-length") == "0"
else None
),
"user": f"{user.id}-{user.keycloak_id}" if user else None,
"files": (
{key: file.filename for key, file in (await request.form()).items()}
if await request.form()
else None
),
}
logger.trace(f"Request {request.state.request_id}: {request_info}")
@staticmethod
async def _log_after_request(request: Request, response: Response):
duration = (time.time() - request.state.start_time) * 1000
logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@ -0,0 +1,28 @@
from contextvars import ContextVar
from typing import Optional, Union
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.websockets import WebSocket
_request_context: ContextVar[Union[Request, None]] = ContextVar("request", default=None)
class RequestMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
_request_context.set(request)
from core.logger import APILogger
logger = APILogger(__name__)
logger.trace("Set new current request")
response = await call_next(request)
return response
def set_request(request: Union[Request, WebSocket, None]):
_request_context.set(request)
def get_request() -> Optional[Request]:
return _request_context.get()

View File

@ -0,0 +1,41 @@
from uuid import uuid4
from ariadne.asgi.handlers import GraphQLTransportWSHandler
from starlette.datastructures import MutableHeaders
from starlette.websockets import WebSocket
from api.middleware.request import set_request
from core.logger import APILogger
logger = APILogger("api.ws")
class AuthenticatedGraphQLTransportWSHandler(GraphQLTransportWSHandler):
def __init__(self, *args, **kwargs):
super().__init__(
*args,
on_connect=self.on_connect,
on_disconnect=self.on_disconnect,
**kwargs,
)
@staticmethod
async def on_connect(ws: WebSocket, message: dict):
ws.state.request_id = uuid4()
logger.info(f"WebSocket connection {ws.state.request_id}")
if "Authorization" not in message:
return True
mutable_headers = MutableHeaders()
mutable_headers["Authorization"] = message.get("Authorization", "")
ws._headers = mutable_headers
set_request(ws)
return True
@staticmethod
async def on_disconnect(ws: WebSocket):
logger.debug(f"WebSocket connection {ws.state.request_id} closed")
return True

View File

@ -2,12 +2,12 @@ import functools
from functools import wraps
from inspect import iscoroutinefunction
from typing import Callable, Union, Optional
from urllib.request import Request
from flask import request
from flask_cors import cross_origin
from starlette.requests import Request
from starlette.routing import Route as StarletteRoute
from api.errors import unauthorized
from api.middleware.request import get_request
from api.route_user_extension import RouteUserExtension
from core.environment import Environment
from data.schemas.administration.api_key import ApiKey
@ -16,10 +16,10 @@ from data.schemas.administration.user import User
class Route(RouteUserExtension):
registered_routes = {}
registered_routes: list[StarletteRoute] = []
@classmethod
async def get_api_key(cls) -> ApiKey:
async def get_api_key(cls, request: Request) -> ApiKey:
auth_header = request.headers.get("Authorization", None)
api_key = auth_header.split(" ")[1]
return await apiKeyDao.find_by_key(api_key)
@ -35,11 +35,13 @@ class Route(RouteUserExtension):
return api_key_from_db is not None and not api_key_from_db.deleted
@classmethod
async def _get_auth_type(cls, auth_header: str) -> Optional[Union[User, ApiKey]]:
async def _get_auth_type(
cls, request: Request, auth_header: str
) -> Optional[Union[User, ApiKey]]:
if auth_header.startswith("Bearer "):
return await cls.get_user()
elif auth_header.startswith("API-Key "):
return await cls.get_api_key()
return await cls.get_api_key(request)
elif (
auth_header.startswith("DEV-User ")
and Environment.get_environment() == "development"
@ -49,11 +51,15 @@ class Route(RouteUserExtension):
@classmethod
async def get_authenticated_user_or_api_key(cls) -> Union[User, ApiKey]:
request = get_request()
if request is None:
raise ValueError("No request found")
auth_header = request.headers.get("Authorization", None)
if not auth_header:
raise Exception("No Authorization header found")
user_or_api_key = await cls._get_auth_type(auth_header)
user_or_api_key = await cls._get_auth_type(request, auth_header)
if user_or_api_key is None:
raise Exception("Invalid Authorization header")
return user_or_api_key
@ -62,14 +68,22 @@ class Route(RouteUserExtension):
async def get_authenticated_user_or_api_key_or_default(
cls,
) -> Optional[Union[User, ApiKey]]:
request = get_request()
if request is None:
return None
auth_header = request.headers.get("Authorization", None)
if not auth_header:
return None
return await cls._get_auth_type(auth_header)
return await cls._get_auth_type(request, auth_header)
@classmethod
async def is_authorized(cls) -> bool:
request = get_request()
if request is None:
return False
auth_header = request.headers.get("Authorization", None)
if not auth_header:
return False
@ -99,26 +113,25 @@ class Route(RouteUserExtension):
)
@wraps(f)
async def decorator(*args, **kwargs):
async def decorator(request: Request, *args, **kwargs):
if skip_in_dev and Environment.get_environment() == "development":
if iscoroutinefunction(f):
return await f(*args, **kwargs)
return f(*args, **kwargs)
return await f(request, *args, **kwargs)
return f(request, *args, **kwargs)
if not await cls.is_authorized():
return unauthorized()
if iscoroutinefunction(f):
return await f(*args, **kwargs)
return f(*args, **kwargs)
return await f(request, *args, **kwargs)
return f(request, *args, **kwargs)
return decorator
@classmethod
def route(cls, path=None, **kwargs):
def inner(fn):
cross_origin(fn)
cls.registered_routes[path] = (fn, kwargs)
cls.registered_routes.append(StarletteRoute(path, fn, **kwargs))
return fn
return inner

View File

@ -1,10 +1,11 @@
from typing import Optional
from flask import request, Request, has_request_context
from keycloak import KeycloakAuthenticationError, KeycloakConnectionError
from starlette.requests import Request
from api.auth.keycloak_client import Keycloak
from api.auth.keycloak_user import KeycloakUser
from api.middleware.request import get_request
from core.get_value import get_value
from core.logger import Logger
from data.schemas.administration.user import User
@ -19,8 +20,8 @@ logger = Logger(__name__)
class RouteUserExtension:
@classmethod
def _get_user_id_from_token(cls) -> Optional[str]:
token = cls.get_token()
def _get_user_id_from_token(cls, request: Request) -> Optional[str]:
token = cls.get_token(request)
if not token:
return None
@ -34,7 +35,7 @@ class RouteUserExtension:
return get_value(user_info, "sub", str)
@staticmethod
def get_token() -> Optional[str]:
def get_token(request: Request) -> Optional[str]:
if "Authorization" not in request.headers:
return None
@ -45,10 +46,11 @@ class RouteUserExtension:
@classmethod
async def get_user(cls) -> Optional[User]:
if not has_request_context():
request = get_request()
if request is None:
return None
user_id = cls._get_user_id_from_token()
user_id = cls._get_user_id_from_token(request)
if not user_id:
return None
@ -56,8 +58,12 @@ class RouteUserExtension:
@classmethod
async def get_dev_user(cls) -> Optional[User]:
request = get_request()
if request is None:
return None
return await userDao.find_single_by(
[{User.keycloak_id: cls.get_token()}, {User.deleted: False}]
[{User.keycloak_id: cls.get_token(request)}, {User.deleted: False}]
)
@classmethod
@ -86,8 +92,8 @@ class RouteUserExtension:
logger.error("Failed to find or create user", e)
@classmethod
async def verify_login(cls, req: Request) -> bool:
auth_header = req.headers.get("Authorization", None)
async def verify_login(cls, request: Request) -> bool:
auth_header = request.headers.get("Authorization", None)
if not auth_header or not auth_header.startswith("Bearer "):
return False

View File

@ -1,7 +1,8 @@
from uuid import uuid4
from flask import send_file
from werkzeug.exceptions import NotFound
from starlette.requests import Request
from starlette.responses import FileResponse
from starlette.exceptions import HTTPException
from api.route import Route
from core.logger import APILogger
@ -9,19 +10,23 @@ from core.logger import APILogger
logger = APILogger(__name__)
@Route.get(f"/api/files/<path:file_path>")
def get_file(file_path: str):
@Route.get("/api/files/{file_path:path}")
async def get_file(request: Request):
file_path = request.path_params["file_path"]
name = file_path
if "/" in file_path:
name = file_path.split("/")[-1]
try:
return send_file(
f"../files/{file_path}", download_name=name, as_attachment=True
return FileResponse(
path=f"files/{file_path}",
filename=name,
media_type="application/octet-stream",
)
except NotFound:
return {"error": "File not found"}, 404
except Exception as e:
error_id = uuid4()
logger.error(f"Error {error_id} getting file {file_path}", e)
return {"error": f"File error. ErrorId: {error_id}"}, 500
except HTTPException as e:
if e.status_code == 404:
return {"error": "File not found"}, 404
else:
error_id = uuid4()
logger.error(f"Error {error_id} getting file {file_path}", e)
return {"error": f"File error. ErrorId: {error_id}"}, 500

View File

@ -1,5 +1,6 @@
from ariadne import graphql
from flask import request, jsonify
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route
from api_graphql.service.schema import schema
@ -10,11 +11,11 @@ logger = Logger(__name__)
@Route.post(f"{BasePath}")
async def graphql_endpoint():
data = request.get_json()
async def graphql_endpoint(request: Request):
data = await request.json()
# Note: Passing the request to the context is optional.
# In Flask, the current request is always accessible as flask.request
# In Starlette, the current request is accessible as request
success, result = await graphql(schema, data, context_value=request)
status_code = 200
@ -24,4 +25,4 @@ async def graphql_endpoint():
]
status_code = max(status_codes, default=200)
return jsonify(result), status_code
return JSONResponse(result, status_code=status_code)

View File

@ -1,4 +1,6 @@
from ariadne.explorer import ExplorerPlayground
from starlette.requests import Request
from starlette.responses import HTMLResponse
from api.route import Route
from core.environment import Environment
@ -10,7 +12,7 @@ logger = Logger(__name__)
@Route.get(f"{BasePath}/playground")
@Route.authorize(skip_in_dev=True)
async def playground():
async def playground(r: Request):
if Environment.get_environment() != "development":
return "", 403
@ -19,7 +21,6 @@ async def playground():
if dev_user:
request_global_headers = {f"Authorization": f"DEV-User {dev_user}"}
return (
ExplorerPlayground(request_global_headers=request_global_headers).html(None),
200,
return HTMLResponse(
ExplorerPlayground(request_global_headers=request_global_headers).html(None)
)

View File

@ -1,7 +1,16 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route
from core.configuration.feature_flags import FeatureFlags
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from version import VERSION
@Route.get(f"/api/version")
def version():
return VERSION
async def version(r: Request):
feature = await FeatureFlags.has_feature(FeatureFlagsEnum.version_endpoint)
if not feature:
return JSONResponse("DISABLED", status_code=403)
return JSONResponse(VERSION)

View File

@ -4,6 +4,7 @@ from api_graphql.abc.filter.bool_filter import BoolFilter
from api_graphql.abc.filter.int_filter import IntFilter
from api_graphql.abc.filter.string_filter import StringFilter
from api_graphql.abc.filter_abc import FilterABC
from api_graphql.filter.fuzzy_filter import FuzzyFilter
class DbModelFilterABC[T](FilterABC[T]):
@ -18,3 +19,5 @@ class DbModelFilterABC[T](FilterABC[T]):
self.add_field("editor", IntFilter)
self.add_field("createdUtc", StringFilter, "created")
self.add_field("updatedUtc", StringFilter, "updated")
self.add_field("fuzzy", FuzzyFilter)

View File

@ -16,7 +16,7 @@ class MutationABC(QueryABC):
self,
name: str,
mutation_name: str,
require_any_permission: list[Permissions] = None,
require_any_permission=None,
public: bool = False,
):
"""
@ -27,6 +27,8 @@ class MutationABC(QueryABC):
:param bool public: Define if the field can resolve without authentication
:return:
"""
if require_any_permission is None:
require_any_permission = []
from api_graphql.definition import QUERIES
self.field(

View File

@ -4,12 +4,13 @@ from enum import Enum
from types import NoneType
from typing import Callable, Type, get_args, Any, Union
from ariadne import ObjectType
from ariadne import ObjectType, SubscriptionType
from graphql import GraphQLResolveInfo
from typing_extensions import deprecated
from api.route import Route
from api_graphql.abc.collection_filter_abc import CollectionFilterABC
from api_graphql.abc.field_abc import FieldABC
from api_graphql.abc.input_abc import InputABC
from api_graphql.abc.sort_abc import Sort
from api_graphql.field.collection_field import CollectionField
@ -20,6 +21,7 @@ from api_graphql.field.mutation_field import MutationField
from api_graphql.field.mutation_field_builder import MutationFieldBuilder
from api_graphql.field.resolver_field import ResolverField
from api_graphql.field.resolver_field_builder import ResolverFieldBuilder
from api_graphql.field.subscription_field import SubscriptionField
from api_graphql.service.collection_result import CollectionResult
from api_graphql.service.exceptions import (
UnauthorizedException,
@ -29,6 +31,7 @@ from api_graphql.service.exceptions import (
from api_graphql.service.query_context import QueryContext
from api_graphql.typing import TRequireAnyPermissions, TRequireAnyResolvers
from core.logger import APILogger
from core.string import first_to_lower
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
@ -40,6 +43,7 @@ class QueryABC(ObjectType):
@abstractmethod
def __init__(self, name: str = __name__):
ObjectType.__init__(self, name)
self._subscriptions: dict[str, SubscriptionType] = {}
@staticmethod
async def _authorize():
@ -67,6 +71,8 @@ class QueryABC(ObjectType):
*args,
**kwargs,
):
info = args[0]
if len(permissions) > 0:
user = await Route.get_authenticated_user_or_api_key_or_default()
if user is not None and all(
@ -132,7 +138,12 @@ class QueryABC(ObjectType):
skip = kwargs["skip"]
collection = await field.dao.find_by(filters, sorts, take, skip)
res = CollectionResult(await field.dao.count(), len(collection), collection)
if field.direct_result:
return collection
res = CollectionResult(
await field.dao.count(filters), len(collection), collection
)
return res
async def collection_wrapper(*args, **kwargs):
@ -169,11 +180,12 @@ class QueryABC(ObjectType):
)
async def resolver_wrapper(*args, **kwargs):
return (
result = (
await field.resolver(*args, **kwargs)
if iscoroutinefunction(field.resolver)
else field.resolver(*args, **kwargs)
)
return result
if isinstance(field, DaoField):
resolver = dao_wrapper
@ -203,6 +215,13 @@ class QueryABC(ObjectType):
resolver = input_wrapper
elif isinstance(field, SubscriptionField):
async def sub_wrapper(sub: QueryABC, info: GraphQLResolveInfo, **kwargs):
return await resolver_wrapper(sub, info, **kwargs)
resolver = sub_wrapper
else:
raise ValueError(f"Unknown field type: {field.name}")
@ -220,7 +239,12 @@ class QueryABC(ObjectType):
result = await resolver(*args, **kwargs)
if field.require_any is not None:
await self._require_any(result, *field.require_any, *args, **kwargs)
await self._require_any(
result,
*field.require_any,
*args,
**kwargs,
)
return result
@ -250,6 +274,9 @@ class QueryABC(ObjectType):
self.field(
MutationFieldBuilder(name)
.with_resolver(f)
.with_change_broadcast(
f"{first_to_lower(self.name.replace("Mutation", ""))}Change"
)
.with_input(input_type, input_key)
.with_require_any_permission(require_any_permission)
.with_public(public)
@ -271,6 +298,8 @@ class QueryABC(ObjectType):
for f in filters:
collection = list(filter(lambda x: f.filter(x), collection))
total_count = len(collection)
if sort is not None:
def f_sort(x: object, k: str):

View File

@ -0,0 +1,51 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from ariadne import SubscriptionType
from api.middleware.request import get_request
from api_graphql.abc.query_abc import QueryABC
from api_graphql.field.subscription_field_builder import SubscriptionFieldBuilder
from core.logger import APILogger
logger = APILogger(__name__)
class SubscriptionABC(SubscriptionType, QueryABC):
@abstractmethod
def __init__(self):
SubscriptionType.__init__(self)
def subscribe(self, builder: SubscriptionFieldBuilder):
field = builder.build()
async def wrapper(*args, **kwargs):
if not field.public:
await self._authorize()
if (
field.require_any is None
and not field.public
and field.require_any_permission
):
await self._require_any_permission(field.require_any_permission)
result = (
await field.resolver(*args, **kwargs)
if iscoroutinefunction(field.resolver)
else field.resolver(*args, **kwargs)
)
if field.require_any is not None:
await self._require_any(
result,
*field.require_any,
*args,
**kwargs,
)
return result
self.set_field(field.name, wrapper)
self.set_source(field.name, field.generator)

View File

@ -4,6 +4,7 @@ import os
from api_graphql.abc.db_model_query_abc import DbModelQueryABC
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.abc.query_abc import QueryABC
from api_graphql.abc.subscription_abc import SubscriptionABC
from api_graphql.query import Query
@ -19,7 +20,7 @@ def import_graphql_schema_part(part: str):
import_graphql_schema_part("queries")
import_graphql_schema_part("mutations")
sub_query_classes = [DbModelQueryABC, MutationABC]
sub_query_classes = [DbModelQueryABC, MutationABC, SubscriptionABC]
query_classes = [
*[y for x in sub_query_classes for y in x.__subclasses__()],
*[x for x in QueryABC.__subclasses__() if x not in sub_query_classes],

View File

@ -20,6 +20,7 @@ class DaoField(FieldABC):
dao: DataAccessObjectABC = None,
filter_type: Type[FilterABC] = None,
sort_type: Type[T] = None,
direct_result: bool = False,
):
FieldABC.__init__(self, name, require_any_permission, require_any, public)
self._name = name
@ -28,6 +29,7 @@ class DaoField(FieldABC):
self._dao = dao
self._filter_type = filter_type
self._sort_type = sort_type
self._direct_result = direct_result
@property
def dao(self) -> Optional[DataAccessObjectABC]:
@ -42,3 +44,7 @@ class DaoField(FieldABC):
@property
def sort_type(self) -> Optional[Type[T]]:
return self._sort_type
@property
def direct_result(self) -> bool:
return self._direct_result

View File

@ -15,6 +15,7 @@ class DaoFieldBuilder(FieldBuilderABC):
self._dao = None
self._filter_type = None
self._sort_type = None
self._direct_result = False
def with_dao(self, dao: DataAccessObjectABC) -> Self:
assert dao is not None, "dao cannot be None"
@ -31,6 +32,10 @@ class DaoFieldBuilder(FieldBuilderABC):
self._sort_type = sort_type
return self
def with_direct_result(self) -> Self:
self._direct_result = True
return self
def build(self) -> DaoField:
assert self._dao is not None, "dao cannot be None"
return DaoField(
@ -41,4 +46,5 @@ class DaoFieldBuilder(FieldBuilderABC):
self._dao,
self._filter_type,
self._sort_type,
self._direct_result,
)

View File

@ -1,7 +1,9 @@
from asyncio import iscoroutinefunction
from typing import Self, Type
from ariadne.types import Resolver
from api.broadcast import broadcast
from api_graphql.abc.field_builder_abc import FieldBuilderABC
from api_graphql.abc.input_abc import InputABC
from api_graphql.field.mutation_field import MutationField
@ -18,9 +20,41 @@ class MutationFieldBuilder(FieldBuilderABC):
def with_resolver(self, resolver: Resolver) -> Self:
assert resolver is not None, "resolver cannot be None"
self._resolver = resolver
return self
def with_broadcast(self, source: str):
assert self._resolver is not None, "resolver cannot be None for broadcast"
resolver = self._resolver
async def resolver_wrapper(*args, **kwargs):
result = (
await resolver(*args, **kwargs)
if iscoroutinefunction(resolver)
else resolver(*args, **kwargs)
)
await broadcast.publish(f"{source}", result)
return result
def with_change_broadcast(self, source: str):
assert self._resolver is not None, "resolver cannot be None for broadcast"
resolver = self._resolver
async def resolver_wrapper(*args, **kwargs):
result = (
await resolver(*args, **kwargs)
if iscoroutinefunction(resolver)
else resolver(*args, **kwargs)
)
await broadcast.publish(f"{source}", {})
return result
self._resolver = resolver_wrapper
return self
def with_input(self, input_type: Type[InputABC], input_key: str = None) -> Self:
self._input_type = input_type
self._input_key = input_key

View File

@ -16,11 +16,17 @@ class ResolverField(FieldABC):
require_any: TRequireAny = None,
public: bool = False,
resolver: Resolver = None,
direct_result: bool = False,
):
FieldABC.__init__(self, name, require_any_permission, require_any, public)
self._resolver = resolver
self._direct_result = direct_result
@property
def resolver(self) -> Optional[Resolver]:
return self._resolver
@property
def direct_result(self) -> bool:
return self._direct_result

View File

@ -12,12 +12,17 @@ class ResolverFieldBuilder(FieldBuilderABC):
FieldBuilderABC.__init__(self, name)
self._resolver = None
self._direct_result = False
def with_resolver(self, resolver: Resolver) -> Self:
assert resolver is not None, "resolver cannot be None"
self._resolver = resolver
return self
def with_direct_result(self) -> Self:
self._direct_result = True
return self
def build(self) -> ResolverField:
assert self._resolver is not None, "resolver cannot be None"
return ResolverField(
@ -26,4 +31,5 @@ class ResolverFieldBuilder(FieldBuilderABC):
self._require_any,
self._public,
self._resolver,
self._direct_result,
)

View File

@ -0,0 +1,32 @@
from typing import Optional
from ariadne.types import Resolver
from api_graphql.abc.field_abc import FieldABC
from api_graphql.typing import TRequireAny
from service.permission.permissions_enum import Permissions
class SubscriptionField(FieldABC):
def __init__(
self,
name: str,
require_any_permission: list[Permissions] = None,
require_any: TRequireAny = None,
public: bool = False,
resolver: Resolver = None,
generator: Resolver = None,
):
FieldABC.__init__(self, name, require_any_permission, require_any, public)
self._resolver = resolver
self._generator = generator
@property
def resolver(self) -> Optional[Resolver]:
return self._resolver
@property
def generator(self) -> Optional[Resolver]:
return self._generator

View File

@ -0,0 +1,46 @@
from typing import Self, AsyncGenerator
from ariadne.types import Resolver
from api.broadcast import broadcast
from api_graphql.abc.field_builder_abc import FieldBuilderABC
from api_graphql.field.subscription_field import SubscriptionField
class SubscriptionFieldBuilder(FieldBuilderABC):
def __init__(self, name: str):
FieldBuilderABC.__init__(self, name)
self._resolver = None
self._generator = None
def with_resolver(self, resolver: Resolver) -> Self:
assert resolver is not None, "resolver cannot be None"
self._resolver = resolver
return self
def with_generator(self, generator: Resolver) -> Self:
assert generator is not None, "generator cannot be None"
self._generator = generator
return self
def build(self) -> SubscriptionField:
assert self._resolver is not None, "resolver cannot be None"
if self._generator is None:
async def generator(*args, **kwargs) -> AsyncGenerator[str, None]:
async with broadcast.subscribe(channel=self._name) as subscriber:
async for message in subscriber:
yield message
self._generator = generator
return SubscriptionField(
self._name,
self._require_any_permission,
self._require_any,
self._public,
self._resolver,
self._generator,
)

View File

@ -0,0 +1,13 @@
from api_graphql.abc.db_model_filter_abc import DbModelFilterABC
from api_graphql.abc.filter.string_filter import StringFilter
class DomainFilter(DbModelFilterABC):
def __init__(
self,
obj: dict,
):
DbModelFilterABC.__init__(self, obj)
self.add_field("name", StringFilter)
self.add_field("description", StringFilter)

View File

@ -0,0 +1,15 @@
from typing import Optional
from api_graphql.abc.filter_abc import FilterABC
class FuzzyFilter(FilterABC):
def __init__(
self,
obj: Optional[dict],
):
FilterABC.__init__(self, obj)
self.add_field("fields", list)
self.add_field("term", str)
self.add_field("threshold", int)

View File

@ -9,6 +9,6 @@ class ShortUrlFilter(DbModelFilterABC):
):
DbModelFilterABC.__init__(self, obj)
self.add_field("short_url", StringFilter)
self.add_field("target_url", StringFilter)
self.add_field("shortUrl", StringFilter, db_name="short_url")
self.add_field("targetUrl", StringFilter, db_name="target_url")
self.add_field("description", StringFilter)

View File

@ -21,11 +21,21 @@ input ApiKeySort {
identifier: SortOrder
deleted: SortOrder
editorId: SortOrder
editor: UserSort
createdUtc: SortOrder
updatedUtc: SortOrder
}
enum ApiKeyFuzzyFields {
identifier
}
input ApiKeyFuzzy {
fields: [ApiKeyFuzzyFields]
term: String
threshold: Int
}
input ApiKeyFilter {
id: IntFilter
identifier: StringFilter

View File

@ -0,0 +1,65 @@
type DomainResult {
totalCount: Int
count: Int
nodes: [Domain]
}
type Domain implements DbModel {
id: ID
name: String
shortUrls: [ShortUrl]
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
input DomainSort {
id: SortOrder
name: SortOrder
deleted: SortOrder
editorId: SortOrder
createdUtc: SortOrder
updatedUtc: SortOrder
}
enum DomainFuzzyFields {
name
}
input DomainFuzzy {
fields: [DomainFuzzyFields]
term: String
threshold: Int
}
input DomainFilter {
id: IntFilter
name: StringFilter
fuzzy: DomainFuzzy
deleted: BooleanFilter
editor: IntFilter
createdUtc: DateFilter
updatedUtc: DateFilter
}
type DomainMutation {
create(input: DomainCreateInput!): Domain
update(input: DomainUpdateInput!): Domain
delete(id: ID!): Boolean
restore(id: ID!): Boolean
}
input DomainCreateInput {
name: String!
}
input DomainUpdateInput {
id: ID!
name: String
}

View File

@ -0,0 +1,19 @@
type FeatureFlag implements DbModel {
id: ID
key: String
value: Boolean
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
type FeatureFlagMutation {
change(input: FeatureFlagInput!): FeatureFlag
}
input FeatureFlagInput {
key: String!
value: Boolean!
}

View File

@ -9,6 +9,7 @@ type Group implements DbModel {
name: String
shortUrls: [ShortUrl]
roles: [Role]
deleted: Boolean
editor: User
@ -26,10 +27,22 @@ input GroupSort {
updatedUtc: SortOrder
}
enum GroupFuzzyFields {
name
}
input GroupFuzzy {
fields: [GroupFuzzyFields]
term: String
threshold: Int
}
input GroupFilter {
id: IntFilter
name: StringFilter
fuzzy: GroupFuzzy
deleted: BooleanFilter
editor: IntFilter
createdUtc: DateFilter
@ -45,9 +58,11 @@ type GroupMutation {
input GroupCreateInput {
name: String!
roles: [ID]
}
input GroupUpdateInput {
id: ID!
name: String
roles: [ID]
}

View File

@ -5,5 +5,10 @@ type Mutation {
role: RoleMutation
group: GroupMutation
domain: DomainMutation
shortUrl: ShortUrlMutation
setting: SettingMutation
userSetting: UserSettingMutation
featureFlag: FeatureFlagMutation
}

View File

@ -11,6 +11,11 @@ type Query {
userHasAnyPermission(permissions: [String]!): Boolean
notExistingUsersFromKeycloak: KeycloakUserResult
domains(filter: [DomainFilter], sort: [DomainSort], skip: Int, take: Int): DomainResult
groups(filter: [GroupFilter], sort: [GroupSort], skip: Int, take: Int): GroupResult
shortUrls(filter: [ShortUrlFilter], sort: [ShortUrlSort], skip: Int, take: Int): ShortUrlResult
settings(key: String): [Setting]
userSettings(key: String): [Setting]
featureFlags(key: String): [FeatureFlag]
}

View File

@ -23,18 +23,31 @@ input RoleSort {
description: SortOrder
deleted: SortOrder
editorId: SortOrder
editor: UserSort
createdUtc: SortOrder
updatedUtc: SortOrder
}
enum RoleFuzzyFields {
name
description
}
input RoleFuzzy {
fields: [RoleFuzzyFields]
term: String
threshold: Int
}
input RoleFilter {
id: IntFilter
name: StringFilter
description: StringFilter
fuzzy: RoleFuzzy
deleted: BooleanFilter
editorId: IntFilter
editor_id: IntFilter
createdUtc: DateFilter
updatedUtc: DateFilter
}

View File

@ -0,0 +1,19 @@
type Setting implements DbModel {
id: ID
key: String
value: String
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
type SettingMutation {
change(input: SettingInput!): Setting
}
input SettingInput {
key: String!
value: String!
}

View File

@ -11,6 +11,7 @@ type ShortUrl implements DbModel {
description: String
visits: Int
group: Group
domain: Domain
loadingScreen: Boolean
deleted: Boolean
@ -31,12 +32,27 @@ input ShortUrlSort {
updatedUtc: SortOrder
}
enum ShortUrlFuzzyFields {
shortUrl
targetUrl
description
}
input ShortUrlFuzzy {
fields: [ShortUrlFuzzyFields]
term: String
threshold: Int
}
input ShortUrlFilter {
id: IntFilter
name: StringFilter
shortUrl: StringFilter
targetUrl: StringFilter
description: StringFilter
loadingScreen: BooleanFilter
fuzzy: ShortUrlFuzzy
deleted: BooleanFilter
editor: IntFilter
createdUtc: DateFilter
@ -48,6 +64,7 @@ type ShortUrlMutation {
update(input: ShortUrlUpdateInput!): ShortUrl
delete(id: ID!): Boolean
restore(id: ID!): Boolean
trackVisit(id: ID!, agent: String): Boolean
}
input ShortUrlCreateInput {
@ -55,6 +72,7 @@ input ShortUrlCreateInput {
targetUrl: String!
description: String
groupId: ID
domainId: ID
loadingScreen: Boolean
}
@ -64,5 +82,6 @@ input ShortUrlUpdateInput {
targetUrl: String
description: String
groupId: ID
domainId: ID
loadingScreen: Boolean
}

View File

@ -0,0 +1,16 @@
scalar SubscriptionChange
type Subscription {
ping: String
apiKeyChange: SubscriptionChange
featureFlagChange: SubscriptionChange
roleChange: SubscriptionChange
settingChange: SubscriptionChange
userChange: SubscriptionChange
userSettingChange: SubscriptionChange
domainChange: SubscriptionChange
groupChange: SubscriptionChange
shortUrlChange: SubscriptionChange
}

View File

@ -35,19 +35,33 @@ input UserSort {
email: SortOrder
deleted: SortOrder
editorId: SortOrder
editor: UserSort
createdUtc: SortOrder
updatedUtc: SortOrder
}
enum UserFuzzyFields {
keycloakId
username
email
}
input UserFuzzy {
fields: [UserFuzzyFields]
term: String
threshold: Int
}
input UserFilter {
id: IntFilter
keycloakId: StringFilter
username: StringFilter
email: StringFilter
fuzzy: UserFuzzy
deleted: BooleanFilter
editor: IntFilter
editor: UserFilter
createdUtc: DateFilter
updatedUtc: DateFilter
}

View File

@ -0,0 +1,19 @@
type UserSetting implements DbModel {
id: ID
key: String
value: String
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
type UserSettingMutation {
change(input: UserSettingInput!): UserSetting
}
input UserSettingInput {
key: String!
value: String!
}

View File

@ -0,0 +1,13 @@
from api_graphql.abc.input_abc import InputABC
class DomainCreateInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._name = self.option("name", str, required=True)
@property
def name(self) -> str:
return self._name

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class DomainUpdateInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._id = self.option("id", int, required=True)
self._name = self.option("name", str)
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._name

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class FeatureFlagInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._key = self.option("key", str, required=True)
self._value = self.option("value", bool, required=True)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> bool:
return self._value

View File

@ -7,7 +7,12 @@ class GroupCreateInput(InputABC):
InputABC.__init__(self, src)
self._name = self.option("name", str, required=True)
self._roles = self.option("roles", list[int])
@property
def name(self) -> str:
return self._name
@property
def roles(self) -> list[int]:
return self._roles

View File

@ -8,6 +8,7 @@ class GroupUpdateInput(InputABC):
self._id = self.option("id", int, required=True)
self._name = self.option("name", str)
self._roles = self.option("roles", list[int])
@property
def id(self) -> int:
@ -16,3 +17,7 @@ class GroupUpdateInput(InputABC):
@property
def name(self) -> str:
return self._name
@property
def roles(self) -> list[int]:
return self._roles

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class SettingInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._key = self.option("key", str, required=True)
self._value = self.option("value", str, required=True)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value

View File

@ -12,6 +12,7 @@ class ShortUrlCreateInput(InputABC):
self._target_url = self.option("targetUrl", str, required=True)
self._description = self.option("description", str)
self._group_id = self.option("groupId", int)
self._domain_id = self.option("domainId", int)
self._loading_screen = self.option("loadingScreen", bool)
@property
@ -30,6 +31,10 @@ class ShortUrlCreateInput(InputABC):
def group_id(self) -> Optional[int]:
return self._group_id
@property
def domain_id(self) -> Optional[int]:
return self._domain_id
@property
def loading_screen(self) -> Optional[str]:
return self._loading_screen

View File

@ -13,6 +13,7 @@ class ShortUrlUpdateInput(InputABC):
self._target_url = self.option("targetUrl", str)
self._description = self.option("description", str)
self._group_id = self.option("groupId", int)
self._domain_id = self.option("domainId", int)
self._loading_screen = self.option("loadingScreen", bool)
@property
@ -35,6 +36,10 @@ class ShortUrlUpdateInput(InputABC):
def group_id(self) -> Optional[int]:
return self._group_id
@property
def domain_id(self) -> Optional[int]:
return self._domain_id
@property
def loading_screen(self) -> Optional[str]:
return self._loading_screen

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class UserSettingInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._key = self.option("key", str, required=True)
self._value = self.option("value", str, required=True)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value

View File

@ -33,6 +33,15 @@ class Mutation(MutationABC):
],
)
self.add_mutation_type(
"domain",
"Domain",
require_any_permission=[
Permissions.domains_create,
Permissions.domains_update,
Permissions.domains_delete,
],
)
self.add_mutation_type(
"group",
"Group",
@ -51,3 +60,22 @@ class Mutation(MutationABC):
Permissions.short_urls_delete,
],
)
self.add_mutation_type(
"setting",
"Setting",
require_any_permission=[
Permissions.settings_update,
],
)
self.add_mutation_type(
"userSetting",
"UserSetting",
)
self.add_mutation_type(
"featureFlag",
"FeatureFlag",
require_any_permission=[
Permissions.administrator,
],
)

View File

@ -0,0 +1,75 @@
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.domain_create_input import DomainCreateInput
from api_graphql.input.domain_update_input import DomainUpdateInput
from api_graphql.input.group_create_input import GroupCreateInput
from api_graphql.input.group_update_input import GroupUpdateInput
from core.logger import APILogger
from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group import Group
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class DomainMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "Domain")
self.mutation(
"create",
self.resolve_create,
DomainCreateInput,
require_any_permission=[Permissions.domains_create],
)
self.mutation(
"update",
self.resolve_update,
DomainUpdateInput,
require_any_permission=[Permissions.domains_update],
)
self.mutation(
"delete",
self.resolve_delete,
require_any_permission=[Permissions.domains_delete],
)
self.mutation(
"restore",
self.resolve_restore,
require_any_permission=[Permissions.domains_delete],
)
@staticmethod
async def resolve_create(obj: GroupCreateInput, *_):
logger.debug(f"create domain: {obj.__dict__}")
domain = Group(
0,
obj.name,
)
nid = await domainDao.create(domain)
return await domainDao.get_by_id(nid)
@staticmethod
async def resolve_update(obj: GroupUpdateInput, *_):
logger.debug(f"update domain: {input}")
if obj.name is not None:
domain = await domainDao.get_by_id(obj.id)
domain.name = obj.name
await domainDao.update(domain)
return await domainDao.get_by_id(obj.id)
@staticmethod
async def resolve_delete(*_, id: str):
logger.debug(f"delete domain: {id}")
domain = await domainDao.get_by_id(id)
await domainDao.delete(domain)
return True
@staticmethod
async def resolve_restore(*_, id: str):
logger.debug(f"restore domain: {id}")
domain = await domainDao.get_by_id(id)
await domainDao.restore(domain)
return True

View File

@ -0,0 +1,32 @@
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.feature_flag_input import FeatureFlagInput
from core.logger import APILogger
from data.schemas.system.feature_flag import FeatureFlag
from data.schemas.system.feature_flag_dao import featureFlagDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class FeatureFlagMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "FeatureFlag")
self.mutation(
"change",
self.resolve_change,
FeatureFlagInput,
require_any_permission=[Permissions.administrator],
)
@staticmethod
async def resolve_change(obj: FeatureFlagInput, *_):
logger.debug(f"create new feature flag: {input}")
setting = await featureFlagDao.find_single_by({FeatureFlag.key: obj.key})
if setting is None:
raise ValueError(f"FeatureFlag {obj.key} not found")
setting.value = obj.value
await featureFlagDao.update(setting)
return await featureFlagDao.get_by_id(setting.id)

View File

@ -1,9 +1,13 @@
from typing import Optional
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.group_create_input import GroupCreateInput
from api_graphql.input.group_update_input import GroupUpdateInput
from core.logger import APILogger
from data.schemas.public.group import Group
from data.schemas.public.group_dao import groupDao
from data.schemas.public.group_role_assignment import GroupRoleAssignment
from data.schemas.public.group_role_assignment_dao import groupRoleAssignmentDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
@ -37,25 +41,61 @@ class GroupMutation(MutationABC):
)
@staticmethod
async def resolve_create(obj: GroupCreateInput, *_):
async def _handle_group_role_assignments(gid: int, roles: Optional[list[int]]):
if roles is None:
return
existing_roles = await groupDao.get_roles(gid)
existing_role_ids = {role.id for role in existing_roles}
new_role_ids = set(roles)
roles_to_add = new_role_ids - existing_role_ids
roles_to_remove = existing_role_ids - new_role_ids
if roles_to_add:
group_role_assignments = [
GroupRoleAssignment(0, gid, role_id) for role_id in roles_to_add
]
await groupRoleAssignmentDao.create_many(group_role_assignments)
if roles_to_remove:
assignments_to_remove = await groupRoleAssignmentDao.find_by(
[
{GroupRoleAssignment.group_id: gid},
{GroupRoleAssignment.role_id: {"in": roles_to_remove}},
]
)
await groupRoleAssignmentDao.delete_many(assignments_to_remove)
@classmethod
async def resolve_create(cls, obj: GroupCreateInput, *_):
logger.debug(f"create group: {obj.__dict__}")
group = Group(
0,
obj.name,
)
nid = await groupDao.create(group)
return await groupDao.get_by_id(nid)
gid = await groupDao.create(group)
@staticmethod
async def resolve_update(obj: GroupUpdateInput, *_):
await cls._handle_group_role_assignments(gid, obj.roles)
return await groupDao.get_by_id(gid)
@classmethod
async def resolve_update(cls, obj: GroupUpdateInput, *_):
logger.debug(f"update group: {input}")
if await groupDao.find_by_id(obj.id) is None:
raise ValueError(f"Group with id {obj.id} not found")
if obj.name is not None:
group = await groupDao.get_by_id(obj.id)
group.name = obj.name
await groupDao.update(group)
await cls._handle_group_role_assignments(obj.id, obj.roles)
return await groupDao.get_by_id(obj.id)
@staticmethod

View File

@ -0,0 +1,32 @@
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.setting_input import SettingInput
from core.logger import APILogger
from data.schemas.system.setting import Setting
from data.schemas.system.setting_dao import settingsDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class SettingMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "Setting")
self.mutation(
"change",
self.resolve_change,
SettingInput,
require_any_permission=[Permissions.settings_update],
)
@staticmethod
async def resolve_change(obj: SettingInput, *_):
logger.debug(f"create new setting: {input}")
setting = await settingsDao.find_single_by({Setting.key: obj.key})
if setting is None:
raise ValueError(f"Setting with key {obj.key} not found")
setting.value = obj.value
await settingsDao.update(setting)
return await settingsDao.get_by_id(setting.id)

View File

@ -1,12 +1,13 @@
from werkzeug.exceptions import NotFound
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.short_url_create_input import ShortUrlCreateInput
from api_graphql.input.short_url_update_input import ShortUrlUpdateInput
from core.logger import APILogger
from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
@ -38,6 +39,11 @@ class ShortUrlMutation(MutationABC):
self.resolve_restore,
require_any_permission=[Permissions.short_urls_delete],
)
self.mutation(
"trackVisit",
self.resolve_track_visit,
require_any_permission=[Permissions.short_urls_update],
)
@staticmethod
async def resolve_create(obj: ShortUrlCreateInput, *_):
@ -49,6 +55,7 @@ class ShortUrlMutation(MutationABC):
obj.target_url,
obj.description,
obj.group_id,
obj.domain_id,
obj.loading_screen,
)
nid = await shortUrlDao.create(short_url)
@ -72,8 +79,18 @@ class ShortUrlMutation(MutationABC):
if obj.group_id is not None:
group_by_id = await groupDao.find_by_id(obj.group_id)
if group_by_id is None:
raise NotFound(f"Group with id {obj.group_id} does not exist")
raise ValueError(f"Group with id {obj.group_id} does not exist")
short_url.group_id = obj.group_id
else:
short_url.group_id = None
if obj.domain_id is not None:
domain_by_id = await domainDao.find_by_id(obj.domain_id)
if domain_by_id is None:
raise ValueError(f"Domain with id {obj.domain_id} does not exist")
short_url.domain_id = obj.domain_id
else:
short_url.domain_id = None
if obj.loading_screen is not None:
short_url.loading_screen = obj.loading_screen
@ -94,3 +111,9 @@ class ShortUrlMutation(MutationABC):
short_url = await shortUrlDao.get_by_id(id)
await shortUrlDao.restore(short_url)
return True
@staticmethod
async def resolve_track_visit(*_, id: int, agent: str):
logger.debug(f"track visit: {id} -- {agent}")
await shortUrlVisitDao.create(ShortUrlVisit(0, id, agent))
return True

View File

@ -0,0 +1,40 @@
from api.route import Route
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.user_setting_input import UserSettingInput
from core.logger import APILogger
from data.schemas.public.user_setting import UserSetting
from data.schemas.public.user_setting_dao import userSettingsDao
from data.schemas.system.setting_dao import settingsDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class UserSettingMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "UserSetting")
self.mutation(
"change",
self.resolve_change,
UserSettingInput,
require_any_permission=[Permissions.settings_update],
)
@staticmethod
async def resolve_change(obj: UserSettingInput, *_):
logger.debug(f"create new setting: {input}")
user = await Route.get_user_or_default()
if user is None:
logger.debug("user not authorized")
return None
setting = await userSettingsDao.find_single_by(
[{UserSetting.user_id: user.id}, {UserSetting.key: obj.key}]
)
if setting is None:
await userSettingsDao.create(UserSetting(0, user.id, obj.key, obj.value))
else:
setting.value = obj.value
await userSettingsDao.update(setting)
return await userSettingsDao.find_by_key(user, obj.key)

View File

@ -0,0 +1,17 @@
from api_graphql.abc.db_model_query_abc import DbModelQueryABC
from data.schemas.public.domain import Domain
from data.schemas.public.group import Group
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
class DomainQuery(DbModelQueryABC):
def __init__(self):
DbModelQueryABC.__init__(self, "Domain")
self.set_field("name", lambda x, *_: x.name)
self.set_field("shortUrls", self._get_urls)
@staticmethod
async def _get_urls(domain: Domain, *_):
return await shortUrlDao.find_by({ShortUrl.domain_id: domain.id})

View File

@ -1,7 +1,11 @@
from api_graphql.abc.db_model_query_abc import DbModelQueryABC
from api_graphql.field.resolver_field_builder import ResolverFieldBuilder
from api_graphql.require_any_resolvers import group_by_assignment_resolver
from data.schemas.public.group import Group
from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from service.permission.permissions_enum import Permissions
class GroupQuery(DbModelQueryABC):
@ -9,8 +13,22 @@ class GroupQuery(DbModelQueryABC):
DbModelQueryABC.__init__(self, "Group")
self.set_field("name", lambda x, *_: x.name)
self.set_field("shortUrls", self._get_urls)
self.field(
ResolverFieldBuilder("shortUrls")
.with_resolver(self._get_urls)
.with_require_any(
[
Permissions.groups,
],
[group_by_assignment_resolver],
)
)
self.set_field("roles", self._get_roles)
@staticmethod
async def _get_urls(group: Group, *_):
return await shortUrlDao.find_by({ShortUrl.group_id: group.id})
@staticmethod
async def _get_roles(group: Group, *_):
return await groupDao.get_roles(group.id)

View File

@ -9,5 +9,6 @@ class ShortUrlQuery(DbModelQueryABC):
self.set_field("targetUrl", lambda x, *_: x.target_url)
self.set_field("description", lambda x, *_: x.description)
self.set_field("group", lambda x, *_: x.group)
self.set_field("domain", lambda x, *_: x.domain)
self.set_field("visits", lambda x, *_: x.visit_count)
self.set_field("loadingScreen", lambda x, *_: x.loading_screen)

View File

@ -5,11 +5,13 @@ from api_graphql.abc.sort_abc import Sort
from api_graphql.field.dao_field_builder import DaoFieldBuilder
from api_graphql.field.resolver_field_builder import ResolverFieldBuilder
from api_graphql.filter.api_key_filter import ApiKeyFilter
from api_graphql.filter.domain_filter import DomainFilter
from api_graphql.filter.group_filter import GroupFilter
from api_graphql.filter.permission_filter import PermissionFilter
from api_graphql.filter.role_filter import RoleFilter
from api_graphql.filter.short_url_filter import ShortUrlFilter
from api_graphql.filter.user_filter import UserFilter
from api_graphql.require_any_resolvers import group_by_assignment_resolver
from data.schemas.administration.api_key import ApiKey
from data.schemas.administration.api_key_dao import apiKeyDao
from data.schemas.administration.user import User
@ -18,10 +20,16 @@ from data.schemas.permission.permission import Permission
from data.schemas.permission.permission_dao import permissionDao
from data.schemas.permission.role import Role
from data.schemas.permission.role_dao import roleDao
from data.schemas.public.domain import Domain
from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group import Group
from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.user_setting import UserSetting
from data.schemas.public.user_setting_dao import userSettingsDao
from data.schemas.system.feature_flag_dao import featureFlagDao
from data.schemas.system.setting_dao import settingsDao
from service.permission.permissions_enum import Permissions
@ -48,7 +56,15 @@ class Query(QueryABC):
.with_dao(roleDao)
.with_filter(RoleFilter)
.with_sort(Sort[Role])
.with_require_any_permission([Permissions.roles])
.with_require_any_permission(
[
Permissions.roles,
Permissions.users_create,
Permissions.users_update,
Permissions.groups_create,
Permissions.groups_update,
]
)
)
self.field(
@ -81,25 +97,55 @@ class Query(QueryABC):
)
self.field(
DaoFieldBuilder("groups")
.with_dao(groupDao)
.with_filter(GroupFilter)
.with_sort(Sort[Group])
DaoFieldBuilder("domains")
.with_dao(domainDao)
.with_filter(DomainFilter)
.with_sort(Sort[Domain])
.with_require_any_permission(
[
Permissions.groups,
Permissions.domains,
Permissions.short_urls_create,
Permissions.short_urls_update,
]
)
)
# partially public to load redirect if not resolved/redirected by api
self.field(
DaoFieldBuilder("groups")
.with_dao(groupDao)
.with_filter(GroupFilter)
.with_sort(Sort[Group])
.with_require_any(
[
Permissions.groups,
Permissions.short_urls_create,
Permissions.short_urls_update,
],
[group_by_assignment_resolver],
)
)
self.field(
DaoFieldBuilder("shortUrls")
.with_dao(shortUrlDao)
.with_filter(ShortUrlFilter)
.with_sort(Sort[ShortUrl])
.with_require_any_permission([Permissions.short_urls])
.with_require_any([Permissions.short_urls], [group_by_assignment_resolver])
)
self.field(
ResolverFieldBuilder("settings")
.with_resolver(self._resolve_settings)
.with_direct_result()
.with_public(True)
)
self.field(
ResolverFieldBuilder("userSettings")
.with_resolver(self._resolve_user_settings)
.with_direct_result()
)
self.field(
ResolverFieldBuilder("featureFlags")
.with_resolver(self._resolve_feature_flags)
.with_direct_result()
)
@staticmethod
@ -132,3 +178,27 @@ class Query(QueryABC):
for x in kc_users
if x["id"] not in existing_user_keycloak_ids
]
@staticmethod
async def _resolve_settings(*args, **kwargs):
if "key" in kwargs:
return [await settingsDao.find_by_key(kwargs["key"])]
return await settingsDao.get_all()
@staticmethod
async def _resolve_user_settings(*args, **kwargs):
user = await Route.get_user()
if user is None:
return None
if "key" in kwargs:
return await userSettingsDao.find_by(
{UserSetting.user_id: user.id, UserSetting.key: kwargs["key"]}
)
return await userSettingsDao.find_by({UserSetting.user_id: user.id})
@staticmethod
async def _resolve_feature_flags(*args, **kwargs):
if "key" in kwargs:
return [await featureFlagDao.find_by_key(kwargs["key"])]
return await featureFlagDao.get_all()

View File

@ -0,0 +1,30 @@
from api_graphql.service.collection_result import CollectionResult
from api_graphql.service.query_context import QueryContext
from data.schemas.public.group_dao import groupDao
from service.permission.permissions_enum import Permissions
async def group_by_assignment_resolver(ctx: QueryContext) -> bool:
if not isinstance(ctx.data, CollectionResult):
return False
if ctx.has_permission(Permissions.short_urls_by_assignment):
groups = [await x.group for x in ctx.data.nodes]
role_ids = {x.id for x in await ctx.user.roles}
filtered_groups = [
g.id
for g in groups
if g is not None
and (roles := await groupDao.get_roles(g.id))
and all(r.id in role_ids for r in roles)
]
ctx.data.nodes = [
node
for node in ctx.data.nodes
if (await node.group) is not None
and (await node.group).id in filtered_groups
]
return True
return True

View File

@ -14,7 +14,7 @@ class QueryContext:
self,
data: Any,
user: Optional[User],
user_permissions: Optional[list[Permission]],
user_permissions: Optional[list[Permissions]],
*args,
**kwargs
):
@ -23,7 +23,7 @@ class QueryContext:
self._user = user
if user_permissions is None:
user_permissions = []
self._user_permissions: list[str] = [x.name for x in user_permissions]
self._user_permissions: list[str] = [x.value for x in user_permissions]
self._resolve_info = None
for arg in args:

View File

@ -5,6 +5,7 @@ from ariadne import make_executable_schema, load_schema_from_path
from api_graphql.definition import QUERIES
from api_graphql.mutation import Mutation
from api_graphql.query import Query
from api_graphql.subscription import Subscription
type_defs = load_schema_from_path(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "../graphql/")
@ -13,5 +14,6 @@ schema = make_executable_schema(
type_defs,
Query(),
Mutation(),
Subscription(),
*QUERIES,
)

View File

@ -0,0 +1,66 @@
from api_graphql.abc.subscription_abc import SubscriptionABC
from api_graphql.field.subscription_field_builder import SubscriptionFieldBuilder
from service.permission.permissions_enum import Permissions
class Subscription(SubscriptionABC):
def __init__(self):
SubscriptionABC.__init__(self)
self.subscribe(
SubscriptionFieldBuilder("ping")
.with_resolver(lambda message, *_: message.message)
.with_public(True)
)
self.subscribe(
SubscriptionFieldBuilder("apiKeyChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.api_keys])
)
self.subscribe(
SubscriptionFieldBuilder("featureFlagChange")
.with_resolver(lambda message, *_: message.message)
.with_public(True)
)
self.subscribe(
SubscriptionFieldBuilder("roleChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.roles])
)
self.subscribe(
SubscriptionFieldBuilder("settingChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.settings])
)
self.subscribe(
SubscriptionFieldBuilder("userChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.users])
)
self.subscribe(
SubscriptionFieldBuilder("userSettingChange")
.with_resolver(lambda message, *_: message.message)
.with_public(True)
)
self.subscribe(
SubscriptionFieldBuilder("domainChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.domains])
)
self.subscribe(
SubscriptionFieldBuilder("groupChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.groups])
)
self.subscribe(
SubscriptionFieldBuilder("shortUrlChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.short_urls])
)

View File

@ -1,11 +1,15 @@
from collections.abc import Awaitable
from typing import Callable, Union, Optional
from typing import Callable, Union, Optional, Coroutine, Any
from api_graphql.service.query_context import QueryContext
from service.permission.permissions_enum import Permissions
TRequireAnyPermissions = Optional[list[Permissions]]
TRequireAnyResolvers = list[
Union[Callable[[QueryContext], bool], Awaitable[[QueryContext], bool]]
Union[
Callable[[QueryContext], bool],
Awaitable[[QueryContext], bool],
Callable[[QueryContext], Coroutine[Any, Any, bool]],
]
]
TRequireAny = tuple[TRequireAnyPermissions, TRequireAnyResolvers]

View File

View File

@ -0,0 +1,20 @@
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from data.schemas.system.feature_flag_dao import featureFlagDao
class FeatureFlags:
_flags = {
FeatureFlagsEnum.version_endpoint.value: True, # 15.01.2025
}
@staticmethod
def get_default(key: FeatureFlagsEnum) -> bool:
return FeatureFlags._flags[key.value]
@staticmethod
async def has_feature(key: FeatureFlagsEnum) -> bool:
value = await featureFlagDao.find_by_key(key.value)
if value is None:
return False
return value.value

View File

@ -0,0 +1,6 @@
from enum import Enum
class FeatureFlagsEnum(Enum):
# modules
version_endpoint = "VersionEndpoint"

1
api/src/core/const.py Normal file
View File

@ -0,0 +1 @@
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z"

View File

@ -4,9 +4,12 @@ from enum import Enum
from types import NoneType
from typing import Generic, Optional, Union, TypeVar, Any, Type
from core.const import DATETIME_FORMAT
from core.database.abc.db_model_abc import DbModelABC
from core.database.database import Database
from core.get_value import get_value
from core.logger import DBLogger
from core.string import camel_to_snake
from core.typing import T, Attribute, AttributeFilters, AttributeSorts
T_DBM = TypeVar("T_DBM", bound=DbModelABC)
@ -23,7 +26,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
self._default_filter_condition = None
self.__attributes: dict[str, type] = {}
self.__joins: dict[str, str] = {}
self.__db_names: dict[str, str] = {}
self.__foreign_tables: dict[str, str] = {}
self.__date_attributes: set[str] = set()
self.__ignored_attributes: set[str] = set()
@ -69,6 +76,40 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if attr_type in [datetime, datetime.datetime]:
self.__date_attributes.add(db_name)
def reference(
self,
attr: Attribute,
primary_attr: Attribute,
foreign_attr: Attribute,
table_name: str,
):
"""
Add a reference to another table for the given attribute
:param str primary_attr: Name of the primary key in the foreign object
:param str foreign_attr: Name of the foreign key in the object
:param str table_name: Name of the table to reference
:return:
"""
if table_name == self._table_name:
return
if isinstance(attr, property):
attr = attr.fget.__name__
if isinstance(primary_attr, property):
primary_attr = primary_attr.fget.__name__
primary_attr = primary_attr.lower().replace("_", "")
if isinstance(foreign_attr, property):
foreign_attr = foreign_attr.fget.__name__
foreign_attr = foreign_attr.lower().replace("_", "")
self.__joins[foreign_attr] = (
f"LEFT JOIN {table_name} ON {table_name}.{primary_attr} = {self._table_name}.{foreign_attr}"
)
self.__foreign_tables[attr] = table_name
def to_object(self, result: dict) -> T_DBM:
"""
Convert a result from the database to an object
@ -89,8 +130,13 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
return self._model_type(**value_map)
async def count(self) -> int:
result = await self._db.select_map(f"SELECT COUNT(*) FROM {self._table_name}")
async def count(self, filters: AttributeFilters = None) -> int:
query = f"SELECT COUNT(*) FROM {self._table_name}"
if filters is not None and (not isinstance(filters, list) or len(filters) > 0):
query += f" WHERE {self._build_conditions(filters)}"
result = await self._db.select_map(query)
return result[0]["count"]
async def get_all(self) -> list[T_DBM]:
@ -370,6 +416,9 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if isinstance(value, NoneType):
return "NULL"
if value is None:
return "NULL"
if isinstance(value, Enum):
return str(value.value)
@ -381,6 +430,12 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
return "ARRAY[]::text[]"
return f"ARRAY[{", ".join([DataAccessObjectABC._get_value_sql(x) for x in value])}]"
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
value = value.replace(tzinfo=datetime.timezone.utc)
return f"'{value.strftime(DATETIME_FORMAT)}'"
return str(value)
@staticmethod
@ -409,15 +464,18 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
take: int = None,
skip: int = None,
) -> str:
query = f"SELECT * FROM {self._table_name}"
query = f"SELECT {self._table_name}.* FROM {self._table_name}"
if filters and len(filters) > 0:
for join in self.__joins:
query += f" {self.__joins[join]}"
if filters is not None and (not isinstance(filters, list) or len(filters) > 0):
query += f" WHERE {self._build_conditions(filters)}"
if sorts and len(sorts) > 0:
if sorts is not None and (not isinstance(sorts, list) or len(sorts) > 0):
query += f" ORDER BY {self._build_order_by(sorts)}"
if take:
if take is not None:
query += f" LIMIT {take}"
if skip:
if skip is not None:
query += f" OFFSET {skip}"
return query
@ -435,12 +493,41 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
for attr, values in f.items():
if isinstance(attr, property):
attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
conditions.extend(
self._build_foreign_conditions(foreign_table, values)
)
continue
if attr == "fuzzy":
conditions.append(
" OR ".join(
self._build_fuzzy_conditions(
[
(
self.__db_names[x]
if x in self.__db_names
else self.__db_names[camel_to_snake(x)]
)
for x in get_value(values, "fields", list[str])
],
get_value(values, "term", str),
get_value(values, "threshold", int, 5),
)
)
)
continue
db_name = self.__db_names[attr]
if isinstance(values, dict):
for operator, value in values.items():
conditions.append(
self._build_condition(db_name, operator, value)
self._build_condition(
f"{self._table_name}.{db_name}", operator, value
)
)
elif isinstance(values, list):
sub_conditions = []
@ -448,18 +535,80 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if isinstance(value, dict):
for operator, val in value.items():
sub_conditions.append(
self._build_condition(db_name, operator, val)
self._build_condition(
f"{self._table_name}.{db_name}", operator, val
)
)
else:
sub_conditions.append(
f"{db_name} = {self._get_value_sql(value)}"
self._get_value_validation_sql(db_name, value)
)
conditions.append(f"({' OR '.join(sub_conditions)})")
else:
conditions.append(f"{db_name} = {self._get_value_sql(values)}")
conditions.append(self._get_value_validation_sql(db_name, values))
return " AND ".join(conditions)
def _build_fuzzy_conditions(
self, fields: list[str], term: str, threshold: int = 10
) -> list[str]:
conditions = []
for field in fields:
conditions.append(
f"levenshtein({field}, '{term}') <= {threshold}"
) # Adjust the threshold as needed
return conditions
def _build_foreign_conditions(self, table: str, values: dict) -> list[str]:
"""
Build SQL conditions for foreign key references
:param table: Foreign table name
:param values: Filter values
:return: List of conditions
"""
conditions = []
for attr, sub_values in values.items():
if isinstance(attr, property):
attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
conditions.extend(
self._build_foreign_conditions(foreign_table, sub_values)
)
continue
db_name = f"{table}.{attr.lower().replace('_', '')}"
if isinstance(sub_values, dict):
for operator, value in sub_values.items():
conditions.append(self._build_condition(db_name, operator, value))
elif isinstance(sub_values, list):
sub_conditions = []
for value in sub_values:
if isinstance(value, dict):
for operator, val in value.items():
sub_conditions.append(
self._build_condition(db_name, operator, val)
)
else:
sub_conditions.append(
self._get_value_validation_sql(db_name, value)
)
conditions.append(f"({' OR '.join(sub_conditions)})")
else:
conditions.append(self._get_value_validation_sql(db_name, sub_values))
return conditions
def _get_value_validation_sql(self, field: str, value: Any):
value = self._get_value_sql(value)
if value == "NULL":
return f"{self._table_name}.{field} IS NULL"
return f"{self._table_name}.{field} = {value}"
def _build_condition(self, db_name: str, operator: str, value: Any) -> str:
"""
Build individual SQL condition based on the operator
@ -520,6 +669,13 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if isinstance(attr, property):
attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
sort_clauses.extend(
self._build_foreign_order_by(foreign_table, direction)
)
continue
match attr:
case "createdUtc":
attr = "created"
@ -537,6 +693,30 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
return ", ".join(sort_clauses)
def _build_foreign_order_by(self, table: str, direction: str) -> list[str]:
"""
Build SQL order by clause for foreign key references
:param table: Foreign table name
:param direction: Sort direction
:return: List of order by clauses
"""
sort_clauses = []
for attr, sub_direction in direction.items():
if isinstance(attr, property):
attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
sort_clauses.extend(
self._build_foreign_order_by(foreign_table, sub_direction)
)
continue
db_name = f"{table}.{attr.lower().replace('_', '')}"
sort_clauses.append(f"{db_name} {sub_direction.upper()}")
return sort_clauses
@staticmethod
async def _get_editor_id(obj: T_DBM):
editor_id = obj.editor_id

View File

@ -1,5 +1,4 @@
import ast
from typing import Type, Optional, Any
from typing import Type, Optional
from core.typing import T
@ -22,26 +21,38 @@ def get_value(
:rtype: Optional[T]
"""
if key in source:
value = source[key]
if isinstance(value, cast_type):
return value
try:
if cast_type == bool:
return value.lower() in ["true", "1"]
if cast_type == list:
subtype = (
cast_type.__args__[0] if hasattr(cast_type, "__args__") else None
)
value = ast.literal_eval(value)
return [
subtype(item) if subtype is not None else item for item in value
]
return cast_type(value)
except (ValueError, TypeError):
return default
else:
if key not in source:
return default
value = source[key]
if isinstance(
value,
cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__,
):
return value
try:
if cast_type == bool:
return value.lower() in ["true", "1"]
if (
cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__
) == list:
if (
not (value.startswith("[") and value.endswith("]"))
and list_delimiter not in value
):
raise ValueError(
"List values must be enclosed in square brackets or use a delimiter."
)
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
value = value.split(list_delimiter)
subtype = cast_type.__args__[0] if hasattr(cast_type, "__args__") else None
return [subtype(item) if subtype is not None else item for item in value]
return cast_type(value)
except (ValueError, TypeError):
return default

View File

@ -1,8 +1,13 @@
import asyncio
import os
import traceback
from datetime import datetime
from api.middleware.request import get_request
from core.environment import Environment
class Logger:
_level = "info"
_levels = ["trace", "debug", "info", "warning", "error", "fatal"]
@ -54,6 +59,34 @@ class Logger:
else:
raise ValueError(f"Invalid log level: {level}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self.source,
"messages": messages,
}
request = get_request()
if request is not None:
structured_message["request"] = {
"url": str(request.url),
"method": request.method if request.scope == "http" else "ws",
"data": (
asyncio.create_task(request.body())
if request.scope == "http"
else None
),
}
return str(structured_message)
def _write_log_to_file(self, content: str):
self._ensure_file_size()
with open(self.log_file, "a") as log_file:
log_file.write(content + "\n")
log_file.close()
def _log(self, level: str, *messages):
try:
if self._levels.index(level) < self._levels.index(self._level):
@ -63,17 +96,18 @@ class Logger:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = f"<{timestamp}> [{level.upper():^7}] [{self._file_prefix:^5}] - [{self.source}]: {' '.join(messages)}"
self._ensure_file_size()
with open(self.log_file, "a") as log_file:
log_file.write(formatted_message + "\n")
log_file.close()
if Environment.get("STRUCTURED_LOGGING", bool, False):
self._write_log_to_file(
self._get_structured_message(level, timestamp, " ".join(messages))
)
else:
self._write_log_to_file(formatted_message)
color = self.COLORS.get(level, self.COLORS["reset"])
reset_color = self.COLORS["reset"]
print(f"{color}{formatted_message}{reset_color}")
print(
f"{self.COLORS.get(level, self.COLORS["reset"])}{formatted_message}{self.COLORS["reset"]}"
)
except Exception as e:
print(f"Error while logging: {e}")
print(f"Error while logging: {e} -> {traceback.format_exc()}")
def trace(self, *messages):
self._log("trace", *messages)

9
api/src/core/string.py Normal file
View File

@ -0,0 +1,9 @@
import re
def first_to_lower(s: str) -> str:
return s[0].lower() + s[1:] if s else s
def camel_to_snake(s: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()

View File

@ -42,17 +42,9 @@ class User(DbModelABC):
@async_property
async def permissions(self):
from data.schemas.permission.role_user_dao import roleUserDao
from data.schemas.permission.role_permission_dao import rolePermissionDao
from data.schemas.permission.permission_dao import permissionDao
from data.schemas.administration.user_dao import userDao
x = [
rp.permission_id
for x in await roleUserDao.get_by_user_id(self.id)
for rp in await rolePermissionDao.get_by_role_id(x.role_id)
]
return await permissionDao.get_by({"id": {"in": x}})
return await userDao.get_permissions(self.id)
async def has_permission(self, permission: Permissions) -> bool:
from data.schemas.administration.user_dao import userDao

View File

@ -33,13 +33,30 @@ class UserDao(DbModelDaoABC[User]):
SELECT COUNT(*)
FROM permission.role_users ru
JOIN permission.role_permissions rp ON ru.roleId = rp.roleId
WHERE ru.userId = {user_id} AND rp.permissionId = {p.id};
WHERE ru.userId = {user_id}
AND rp.permissionId = {p.id}
AND ru.deleted = FALSE
AND rp.deleted = FALSE;
"""
)
if result is None or len(result) == 0:
return False
return True
return result[0]["count"] > 0
async def get_permissions(self, user_id: int) -> list[Permissions]:
result = await self._db.select_map(
f"""
SELECT p.*
FROM permission.permissions p
JOIN permission.role_permissions rp ON p.id = rp.permissionId
JOIN permission.role_users ru ON rp.roleId = ru.roleId
WHERE ru.userId = {user_id}
AND rp.deleted = FALSE
AND ru.deleted = FALSE;
"""
)
return [Permissions(p["name"]) for p in result]
userDao = UserDao()

View File

@ -0,0 +1,27 @@
from datetime import datetime
from typing import Optional
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class Domain(DbModelABC):
def __init__(
self,
id: SerialId,
name: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value: str):
self._name = value

View File

@ -0,0 +1,22 @@
from core.logger import DBLogger
from data.schemas.public.domain import Domain
from data.schemas.public.group import Group
logger = DBLogger(__name__)
from core.database.abc.db_model_dao_abc import DbModelDaoABC
class DomainDao(DbModelDaoABC[Group]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, Group, "public.domains")
self.attribute(Domain.name, str)
async def get_by_name(self, name: str) -> Group:
result = await self._db.select_map(
f"SELECT * FROM {self._table_name} WHERE Name = '{name}'"
)
return self.to_object(result[0])
domainDao = DomainDao()

View File

@ -17,5 +17,19 @@ class GroupDao(DbModelDaoABC[Group]):
)
return self.to_object(result[0])
async def get_roles(self, group_id: int):
result = await self._db.select_map(
f"""
SELECT r.*
FROM permission.roles r
JOIN public.group_role_assignments gra ON r.id = gra.roleId
WHERE gra.groupId = {group_id}
AND gra.deleted = FALSE
"""
)
from data.schemas.permission.role_dao import roleDao
return [roleDao.to_object(x) for x in result]
groupDao = GroupDao()

View File

@ -0,0 +1,43 @@
from datetime import datetime
from typing import Optional
from async_property import async_property
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class GroupRoleAssignment(DbModelABC):
def __init__(
self,
id: SerialId,
group_id: SerialId,
role_id: SerialId,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._group_id = group_id
self._role_id = role_id
@property
def group_id(self) -> SerialId:
return self._group_id
@async_property
async def group(self):
from data.schemas.public.group_dao import groupDao
return await groupDao.get_by_id(self._group_id)
@property
def role_id(self) -> SerialId:
return self._role_id
@async_property
async def role(self):
from data.schemas.permission.role_dao import roleDao
return await roleDao.get_by_id(self._role_id)

View File

@ -0,0 +1,37 @@
from core.logger import DBLogger
from data.schemas.public.group_role_assignment import GroupRoleAssignment
logger = DBLogger(__name__)
from core.database.abc.db_model_dao_abc import DbModelDaoABC
class GroupRoleAssignmentDao(DbModelDaoABC[GroupRoleAssignment]):
def __init__(self):
DbModelDaoABC.__init__(
self, __name__, GroupRoleAssignment, "public.group_role_assignments"
)
self.attribute(GroupRoleAssignment.group_id, int)
self.attribute(GroupRoleAssignment.role_id, int)
async def get_by_group_id(
self, gid: int, with_deleted=False
) -> list[GroupRoleAssignment]:
f = [{GroupRoleAssignment.group_id: gid}]
if not with_deleted:
f.append({GroupRoleAssignment.deleted: False})
return await self.find_by(f)
async def get_by_role_id(
self, rid: int, with_deleted=False
) -> list[GroupRoleAssignment]:
f = [{GroupRoleAssignment.role_id: rid}]
if not with_deleted:
f.append({GroupRoleAssignment.deleted: False})
return await self.find_by(f)
groupRoleAssignmentDao = GroupRoleAssignmentDao()

View File

@ -16,6 +16,7 @@ class ShortUrl(DbModelABC):
target_url: str,
description: Optional[str],
group_id: Optional[SerialId],
domain_id: Optional[SerialId],
loading_screen: Optional[str] = None,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
@ -27,6 +28,10 @@ class ShortUrl(DbModelABC):
self._target_url = target_url
self._description = description
self._group_id = group_id
self._domain_id = domain_id
if loading_screen is None or loading_screen == "":
loading_screen = False
self._loading_screen = loading_screen
@property
@ -70,6 +75,23 @@ class ShortUrl(DbModelABC):
return await groupDao.get_by_id(self._group_id)
@property
def domain_id(self) -> SerialId:
return self._domain_id
@domain_id.setter
def domain_id(self, value: SerialId):
self._domain_id = value
@async_property
async def domain(self) -> Optional[Group]:
if self._domain_id is None:
return None
from data.schemas.public.domain_dao import domainDao
return await domainDao.get_by_id(self._domain_id)
@async_property
async def visit_count(self) -> int:
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao

View File

@ -13,6 +13,7 @@ class ShortUrlDao(DbModelDaoABC[ShortUrl]):
self.attribute(ShortUrl.target_url, str)
self.attribute(ShortUrl.description, str)
self.attribute(ShortUrl.group_id, int)
self.attribute(ShortUrl.domain_id, int)
self.attribute(ShortUrl.loading_screen, bool)

View File

@ -0,0 +1,48 @@
from datetime import datetime
from typing import Optional, Union
from async_property import async_property
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class UserSetting(DbModelABC):
def __init__(
self,
id: SerialId,
user_id: SerialId,
key: str,
value: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._user_id = user_id
self._key = key
self._value = value
@property
def user_id(self) -> SerialId:
return self._user_id
@async_property
async def user(self):
from data.schemas.administration.user_dao import userDao
return await userDao.get_by_id(self._user_id)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value
@value.setter
def value(self, value: Union[str, int, float, bool]):
self._value = str(value)

View File

@ -0,0 +1,24 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.administration.user import User
from data.schemas.public.user_setting import UserSetting
logger = DBLogger(__name__)
class UserSettingDao(DbModelDaoABC[UserSetting]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, UserSetting, "public.user_settings")
self.attribute(UserSetting.user_id, int)
self.attribute(UserSetting.key, str)
self.attribute(UserSetting.value, str)
async def find_by_key(self, user: User, key: str) -> UserSetting:
return await self.find_single_by(
[{UserSetting.user_id: user.id}, {UserSetting.key: key}]
)
userSettingsDao = UserSettingDao()

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class FeatureFlag(DbModelABC):
def __init__(
self,
id: SerialId,
key: str,
value: bool,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._key = key
self._value = value
@property
def key(self) -> str:
return self._key
@property
def value(self) -> bool:
return self._value
@value.setter
def value(self, value: bool):
self._value = value

View File

@ -0,0 +1,20 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.system.feature_flag import FeatureFlag
logger = DBLogger(__name__)
class FeatureFlagDao(DbModelDaoABC[FeatureFlag]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, FeatureFlag, "system.feature_flags")
self.attribute(FeatureFlag.key, str)
self.attribute(FeatureFlag.value, bool)
async def find_by_key(self, key: str) -> FeatureFlag:
return await self.find_single_by({FeatureFlag.key: key})
featureFlagDao = FeatureFlagDao()

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional, Union
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class Setting(DbModelABC):
def __init__(
self,
id: SerialId,
key: str,
value: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._key = key
self._value = value
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value
@value.setter
def value(self, value: Union[str, int, float, bool]):
self._value = str(value)

View File

@ -0,0 +1,20 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.system.setting import Setting
logger = DBLogger(__name__)
class SettingDao(DbModelDaoABC[Setting]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, Setting, "system.settings")
self.attribute(Setting.key, str)
self.attribute(Setting.value, str)
async def find_by_key(self, key: str) -> Setting:
return await self.find_single_by({Setting.key: key})
settingsDao = SettingDao()

View File

@ -0,0 +1,32 @@
CREATE
SCHEMA IF NOT EXISTS public;
-- groups
CREATE TABLE IF NOT EXISTS public.domains
(
Id SERIAL PRIMARY KEY,
Name VARCHAR(255) NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS public.domains_history
(
LIKE public.domains
);
CREATE TRIGGER domains_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.domains
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();
ALTER TABLE public.short_urls
ADD COLUMN domainId INT NULL REFERENCES public.domains (Id);
ALTER TABLE public.short_urls_history
ADD COLUMN domainId INT NULL REFERENCES public.domains (Id);

View File

@ -0,0 +1,27 @@
CREATE
SCHEMA IF NOT EXISTS public;
-- groups
CREATE TABLE IF NOT EXISTS public.group_role_assignments
(
Id SERIAL PRIMARY KEY,
GroupId INT NOT NULL REFERENCES public.groups (Id),
RoleId INT NOT NULL REFERENCES permission.roles (Id),
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS public.group_role_assignments_history
(
LIKE public.group_role_assignments
);
CREATE TRIGGER group_role_assignment_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.group_role_assignments
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,24 @@
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system.settings
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value TEXT NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE system.settings_history
(
LIKE system.settings
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON system.settings
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,24 @@
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system.feature_flags
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value BOOLEAN NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE system.feature_flags_history
(
LIKE system.feature_flags
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON system.feature_flags
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,25 @@
CREATE SCHEMA IF NOT EXISTS public;
CREATE TABLE IF NOT EXISTS public.user_settings
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value TEXT NOT NULL,
UserId INT NOT NULL REFERENCES administration.users (Id) ON DELETE CASCADE,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE public.user_settings_history
(
LIKE public.user_settings
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.user_settings
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,40 @@
from core.configuration.feature_flags import FeatureFlags
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from core.logger import DBLogger
from data.abc.data_seeder_abc import DataSeederABC
from data.schemas.system.feature_flag import FeatureFlag
from data.schemas.system.feature_flag_dao import featureFlagDao
logger = DBLogger(__name__)
class FeatureFlagsSeeder(DataSeederABC):
def __init__(self):
DataSeederABC.__init__(self)
async def seed(self):
logger.info("Seeding feature flags")
feature_flags = await featureFlagDao.get_all()
feature_flag_keys = [x.key for x in feature_flags]
possible_feature_flags = {
x.value: FeatureFlags.get_default(x) for x in FeatureFlagsEnum
}
to_create = [
FeatureFlag(0, x, possible_feature_flags[x])
for x in possible_feature_flags.keys()
if x not in feature_flag_keys
]
if len(to_create) > 0:
await featureFlagDao.create_many(to_create)
to_create_dicts = {x.key: x.value for x in to_create}
logger.debug(f"Created feature flags: {to_create_dicts}")
to_delete = [
x for x in feature_flags if x.key not in possible_feature_flags.keys()
]
if len(to_delete) > 0:
await featureFlagDao.delete_many(to_delete, hard_delete=True)
to_delete_dicts = {x.key: x.value for x in to_delete}
logger.debug(f"Deleted feature flags: {to_delete_dicts}")

Some files were not shown because too many files have changed in this diff Show More