Merge pull request 'dev' (#17) from dev into master
All checks were successful
Build on push / prepare (push) Successful in 5s
Build on push / build-redirector (push) Successful in 29s
Build on push / build-api (push) Successful in 29s
Build on push / build-web (push) Successful in 53s

Reviewed-on: #17
This commit is contained in:
Sven Heidemann 2025-03-12 10:09:01 +01:00
commit 433188995e
224 changed files with 7103 additions and 1862 deletions

View File

@ -1,13 +1,48 @@
name: Build dev on push name: Build on push
run-name: Build dev on push run-name: Build on push
on: on:
push: push:
branches: branches:
- dev - dev
jobs: jobs:
prepare:
runs-on: [runner]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Get Date and Build Number
run: |
git fetch --tags
git tag
DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
BUILD_NUMBER=$(($TAG_COUNT + 1))
BUILD_VERSION="${DATE}.${BUILD_NUMBER}-dev"
echo "$BUILD_VERSION" > version.txt
echo "VERSION $BUILD_VERSION"
- name: Create Git Tag for Build
run: |
git config user.name "ci"
git config user.email "dev@sh-edraft.de"
echo "tag $(cat version.txt)"
git tag $(cat version.txt)
git push origin --tags
- name: Upload build version artifact
uses: actions/upload-artifact@v3
with:
name: version
path: version.txt
build-api: build-api:
runs-on: [runner] runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps: steps:
- name: Clone Repository - name: Clone Repository
@ -15,10 +50,16 @@ jobs:
with: with:
token: ${{ secrets.CI_ACCESS_TOKEN }} token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker - name: Build docker
run: | run: |
cd api cd api
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-api-dev:$(cat ../version.txt) . echo "VERSION = \"$(cat version.txt)\"" > version.py
docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de - name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1 uses: https://github.com/docker/login-action@v1
@ -29,10 +70,11 @@ jobs:
- name: Push image - name: Push image
run: | run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api-dev:$(cat version.txt) docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat version.txt)
build-redirector: build-redirector:
runs-on: [runner] runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps: steps:
- name: Clone Repository - name: Clone Repository
@ -40,10 +82,15 @@ jobs:
with: with:
token: ${{ secrets.CI_ACCESS_TOKEN }} token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker - name: Build docker
run: | run: |
cd api cd api
docker build -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector-dev:$(cat ../version.txt) . docker build --no-cache -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de - name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1 uses: https://github.com/docker/login-action@v1
@ -54,10 +101,11 @@ jobs:
- name: Push image - name: Push image
run: | run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-redirector-dev:$(cat version.txt) docker push git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat version.txt)
build-web: build-web:
runs-on: [runner] runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps: steps:
- name: Clone Repository - name: Clone Repository
@ -65,6 +113,11 @@ jobs:
with: with:
token: ${{ secrets.CI_ACCESS_TOKEN }} token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Prepare web build - name: Prepare web build
run: | run: |
cd web cd web
@ -78,7 +131,7 @@ jobs:
- name: Build docker - name: Build docker
run: | run: |
cd web cd web
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-web-dev:$(cat ../version.txt) . docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de - name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1 uses: https://github.com/docker/login-action@v1
@ -89,4 +142,4 @@ jobs:
- name: Push image - name: Push image
run: | run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-web-dev:$(cat version.txt) docker push git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat version.txt)

View File

@ -6,7 +6,7 @@ on:
- master - master
jobs: jobs:
build-api: prepare:
runs-on: [runner] runs-on: [runner]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps: steps:
@ -15,10 +15,51 @@ jobs:
with: with:
token: ${{ secrets.CI_ACCESS_TOKEN }} token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Get Date and Build Number
run: |
git fetch
git tag
DATE=$(date +'%Y.%m.%d')
TAG_COUNT=$(git tag -l "${DATE}.*" | wc -l)
BUILD_NUMBER=$(($TAG_COUNT + 1))
BUILD_VERSION="${DATE}.${BUILD_NUMBER}"
echo "$BUILD_VERSION" > version.txt
echo "VERSION $BUILD_VERSION"
- name: Create Git Tag for Build
run: |
git config user.name "ci"
git config user.email "dev@sh-edraft.de"
echo "tag $(cat version.txt)"
git tag $(cat version.txt)
git push origin --tags
- name: Upload build version artifact
uses: actions/upload-artifact@v3
with:
name: version
path: version.txt
build-api:
runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker - name: Build docker
run: | run: |
cd api cd api
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat ../version.txt) . echo "VERSION = \"$(cat version.txt)\"" > version.py
docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de - name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1 uses: https://github.com/docker/login-action@v1
@ -30,9 +71,10 @@ jobs:
- name: Push image - name: Push image
run: | run: |
docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat version.txt) docker push git.sh-edraft.de/sh-edraft.de/open-redirect-api:$(cat version.txt)
build-redirector: build-redirector:
runs-on: [runner] runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps: steps:
- name: Clone Repository - name: Clone Repository
@ -40,10 +82,15 @@ jobs:
with: with:
token: ${{ secrets.CI_ACCESS_TOKEN }} token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Build docker - name: Build docker
run: | run: |
cd api cd api
docker build -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat ../version.txt) . docker build --no-cache -f dockerfile_redirector -t git.sh-edraft.de/sh-edraft.de/open-redirect-redirector:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de - name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1 uses: https://github.com/docker/login-action@v1
@ -58,6 +105,7 @@ jobs:
build-web: build-web:
runs-on: [runner] runs-on: [runner]
needs: prepare
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps: steps:
- name: Clone Repository - name: Clone Repository
@ -65,6 +113,11 @@ jobs:
with: with:
token: ${{ secrets.CI_ACCESS_TOKEN }} token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Download build version artifact
uses: actions/download-artifact@v3
with:
name: version
- name: Prepare web build - name: Prepare web build
run: | run: |
cd web cd web
@ -78,7 +131,7 @@ jobs:
- name: Build docker - name: Build docker
run: | run: |
cd web cd web
docker build -t git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat ../version.txt) . docker build --no-cache -t git.sh-edraft.de/sh-edraft.de/open-redirect-web:$(cat ../version.txt) .
- name: Login to registry git.sh-edraft.de - name: Login to registry git.sh-edraft.de
uses: https://github.com/docker/login-action@v1 uses: https://github.com/docker/login-action@v1

View File

@ -0,0 +1,29 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Installing dependencies
working-directory: ./api
run: |
python3.12 -m pip install -r requirements-dev.txt
- name: Checking black
working-directory: ./api
run: python3.12 -m black src --check

View File

@ -1,39 +0,0 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-before-merge:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
run: npm ci
- name: Checking eslint
run: npm run lint
- name: Setup chrome
run: |
wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -
echo "deb http://dl.google.com/linux/chrome/deb/ stable main" > /etc/apt/sources.list.d/google.list
apt-get update
apt-get install -y google-chrome-stable xvfb
- name: Testing
run: npm run test:ci

View File

@ -0,0 +1,79 @@
name: Test before pr merge
run-name: Test before pr merge
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
- ready_for_review
jobs:
test-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
working-directory: ./web
run: npm ci
- name: Checking eslint
working-directory: ./web
run: npm run lint
test-translation-lint:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
working-directory: ./web
run: npm ci
- name: Checking translations
working-directory: ./web
run: npm run lint:translations
test-before-merge:
runs-on: [ runner ]
container: git.sh-edraft.de/sh-edraft.de/act-runner:latest
steps:
- name: Clone Repository
uses: https://github.com/actions/checkout@v3
with:
token: ${{ secrets.CI_ACCESS_TOKEN }}
- name: Setup node
uses: https://github.com/actions/setup-node@v3
- name: Installing dependencies
working-directory: ./web
run: npm ci
- name: Setup chrome
working-directory: ./web
run: |
wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -
echo "deb http://dl.google.com/linux/chrome/deb/ stable main" > /etc/apt/sources.list.d/google.list
apt-get update
apt-get install -y google-chrome-stable xvfb
- name: Testing
working-directory: ./web
run: npm run test:ci

View File

@ -1,10 +1,13 @@
ariadne==0.23.0 ariadne==0.23.0
eventlet==0.37.0 broadcaster==0.3.1
graphql-core==3.2.5 graphql-core==3.2.5
Flask[async]==3.1.0
Flask-Cors==5.0.0
async-property==0.2.2 async-property==0.2.2
python-keycloak==4.7.3
psycopg[binary]==3.2.3 psycopg[binary]==3.2.3
psycopg-pool==3.2.4 psycopg-pool==3.2.4
Werkzeug==3.1.3 uvicorn==0.34.0
starlette==0.46.0
requests==2.32.3
Jinja2==3.1.5
python-keycloak==5.3.1
python-multipart==0.0.20
websockets==15.0

View File

@ -1,78 +1,45 @@
import importlib import importlib
import os import os
import time from typing import Optional
from uuid import uuid4
from flask import Flask, request, g from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route from core.environment import Environment
from core.logger import APILogger from core.logger import APILogger
app = Flask(__name__)
logger = APILogger(__name__) logger = APILogger(__name__)
def filter_relevant_headers(headers: dict) -> dict: class API:
relevant_keys = { app: Optional[Starlette] = None
"Content-Type",
"Host",
"Connection",
"User-Agent",
"Origin",
"Referer",
"Accept",
}
return {key: value for key, value in headers.items() if key in relevant_keys}
@classmethod
def create(cls, app: Starlette):
cls.app = app
@app.before_request @staticmethod
async def log_request(): async def handle_exception(request: Request, exc: Exception):
g.request_id = uuid4() logger.error(f"Request {request.state.request_id}", exc)
g.start_time = time.time() return JSONResponse({"error": str(exc)}, status_code=500)
logger.debug(
f"Request {g.request_id}: {request.method}@{request.path} from {request.remote_addr}"
)
user = await Route.get_user()
request_info = { @staticmethod
"headers": filter_relevant_headers(dict(request.headers)), def get_allowed_origins():
"args": request.args.to_dict(), client_urls = Environment.get("CLIENT_URLS", str)
"form-data": request.form.to_dict(), if client_urls is None or client_urls == "":
"payload": request.get_json(silent=True), allowed_origins = ["*"]
"user": f"{user.id}-{user.keycloak_id}" if user else None, logger.warning("No allowed origins specified, allowing all origins")
"files": ( else:
{key: file.filename for key, file in request.files.items()} allowed_origins = client_urls.split(",")
if request.files
else None
),
}
logger.trace(f"Request {g.request_id}: {request_info}") return allowed_origins
@staticmethod
@app.after_request def import_routes():
def log_after_request(response): # used to import all routes
# calc the time it took to process the request routes_dir = os.path.join(os.path.dirname(__file__), "routes")
duration = (time.time() - g.start_time) * 1000 for filename in os.listdir(routes_dir):
logger.info( if filename.endswith(".py") and filename != "__init__.py":
f"Request finished {g.request_id}: {response.status_code}-{request.method}@{request.path} from {request.remote_addr} in {duration:.2f}ms" module_name = f"api.routes.{filename[:-3]}"
) importlib.import_module(module_name)
return response
@app.errorhandler(Exception)
def handle_exception(e):
logger.error(f"Request {g.request_id}", e)
return {"error": str(e)}, 500
# used to import all routes
routes_dir = os.path.join(os.path.dirname(__file__), "routes")
for filename in os.listdir(routes_dir):
if filename.endswith(".py") and filename != "__init__.py":
module_name = f"api.routes.{filename[:-3]}"
importlib.import_module(module_name)
# Explicitly register the routes
for route, (view_func, options) in Route.registered_routes.items():
app.add_url_rule(route, view_func=view_func, **options)

5
api/src/api/broadcast.py Normal file
View File

@ -0,0 +1,5 @@
from typing import Optional
from broadcaster import Broadcast
broadcast: Optional[Broadcast] = Broadcast("memory://")

View File

@ -1,9 +1,9 @@
from flask import jsonify from starlette.responses import JSONResponse
def unauthorized(): def unauthorized():
return jsonify({"error": "Unauthorized"}), 401 return JSONResponse({"error": "Unauthorized"}, 401)
def forbidden(): def forbidden():
return jsonify({"error": "Unauthorized"}), 401 return JSONResponse({"error": "Unauthorized"}, 401)

View File

View File

@ -0,0 +1,73 @@
import time
from uuid import uuid4
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from api.route import Route
from core.logger import APILogger
logger = APILogger("api.api")
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
await self._log_request(request)
response = await call_next(request)
await self._log_after_request(request, response)
return response
@staticmethod
def _filter_relevant_headers(headers: dict) -> dict:
relevant_keys = {
"content-type",
"host",
"connection",
"user-agent",
"origin",
"referer",
"accept",
}
return {key: value for key, value in headers.items() if key in relevant_keys}
@classmethod
async def _log_request(cls, request: Request):
request.state.request_id = uuid4()
request.state.start_time = time.time()
logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
)
user = await Route.get_user()
request_info = {
"headers": cls._filter_relevant_headers(dict(request.headers)),
"args": dict(request.query_params),
"form-data": (
await request.form()
if request.headers.get("content-type")
== "application/x-www-form-urlencoded"
else None
),
"payload": (
await request.json()
if request.headers.get("content-length") == "0"
else None
),
"user": f"{user.id}-{user.keycloak_id}" if user else None,
"files": (
{key: file.filename for key, file in (await request.form()).items()}
if await request.form()
else None
),
}
logger.trace(f"Request {request.state.request_id}: {request_info}")
@staticmethod
async def _log_after_request(request: Request, response: Response):
duration = (time.time() - request.state.start_time) * 1000
logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@ -0,0 +1,28 @@
from contextvars import ContextVar
from typing import Optional, Union
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.websockets import WebSocket
_request_context: ContextVar[Union[Request, None]] = ContextVar("request", default=None)
class RequestMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
_request_context.set(request)
from core.logger import APILogger
logger = APILogger(__name__)
logger.trace("Set new current request")
response = await call_next(request)
return response
def set_request(request: Union[Request, WebSocket, None]):
_request_context.set(request)
def get_request() -> Optional[Request]:
return _request_context.get()

View File

@ -0,0 +1,41 @@
from uuid import uuid4
from ariadne.asgi.handlers import GraphQLTransportWSHandler
from starlette.datastructures import MutableHeaders
from starlette.websockets import WebSocket
from api.middleware.request import set_request
from core.logger import APILogger
logger = APILogger("api.ws")
class AuthenticatedGraphQLTransportWSHandler(GraphQLTransportWSHandler):
def __init__(self, *args, **kwargs):
super().__init__(
*args,
on_connect=self.on_connect,
on_disconnect=self.on_disconnect,
**kwargs,
)
@staticmethod
async def on_connect(ws: WebSocket, message: dict):
ws.state.request_id = uuid4()
logger.info(f"WebSocket connection {ws.state.request_id}")
if "Authorization" not in message:
return True
mutable_headers = MutableHeaders()
mutable_headers["Authorization"] = message.get("Authorization", "")
ws._headers = mutable_headers
set_request(ws)
return True
@staticmethod
async def on_disconnect(ws: WebSocket):
logger.debug(f"WebSocket connection {ws.state.request_id} closed")
return True

View File

@ -2,12 +2,12 @@ import functools
from functools import wraps from functools import wraps
from inspect import iscoroutinefunction from inspect import iscoroutinefunction
from typing import Callable, Union, Optional from typing import Callable, Union, Optional
from urllib.request import Request
from flask import request from starlette.requests import Request
from flask_cors import cross_origin from starlette.routing import Route as StarletteRoute
from api.errors import unauthorized from api.errors import unauthorized
from api.middleware.request import get_request
from api.route_user_extension import RouteUserExtension from api.route_user_extension import RouteUserExtension
from core.environment import Environment from core.environment import Environment
from data.schemas.administration.api_key import ApiKey from data.schemas.administration.api_key import ApiKey
@ -16,10 +16,10 @@ from data.schemas.administration.user import User
class Route(RouteUserExtension): class Route(RouteUserExtension):
registered_routes = {} registered_routes: list[StarletteRoute] = []
@classmethod @classmethod
async def get_api_key(cls) -> ApiKey: async def get_api_key(cls, request: Request) -> ApiKey:
auth_header = request.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
api_key = auth_header.split(" ")[1] api_key = auth_header.split(" ")[1]
return await apiKeyDao.find_by_key(api_key) return await apiKeyDao.find_by_key(api_key)
@ -35,11 +35,13 @@ class Route(RouteUserExtension):
return api_key_from_db is not None and not api_key_from_db.deleted return api_key_from_db is not None and not api_key_from_db.deleted
@classmethod @classmethod
async def _get_auth_type(cls, auth_header: str) -> Optional[Union[User, ApiKey]]: async def _get_auth_type(
cls, request: Request, auth_header: str
) -> Optional[Union[User, ApiKey]]:
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
return await cls.get_user() return await cls.get_user()
elif auth_header.startswith("API-Key "): elif auth_header.startswith("API-Key "):
return await cls.get_api_key() return await cls.get_api_key(request)
elif ( elif (
auth_header.startswith("DEV-User ") auth_header.startswith("DEV-User ")
and Environment.get_environment() == "development" and Environment.get_environment() == "development"
@ -49,11 +51,15 @@ class Route(RouteUserExtension):
@classmethod @classmethod
async def get_authenticated_user_or_api_key(cls) -> Union[User, ApiKey]: async def get_authenticated_user_or_api_key(cls) -> Union[User, ApiKey]:
request = get_request()
if request is None:
raise ValueError("No request found")
auth_header = request.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
if not auth_header: if not auth_header:
raise Exception("No Authorization header found") raise Exception("No Authorization header found")
user_or_api_key = await cls._get_auth_type(auth_header) user_or_api_key = await cls._get_auth_type(request, auth_header)
if user_or_api_key is None: if user_or_api_key is None:
raise Exception("Invalid Authorization header") raise Exception("Invalid Authorization header")
return user_or_api_key return user_or_api_key
@ -62,14 +68,22 @@ class Route(RouteUserExtension):
async def get_authenticated_user_or_api_key_or_default( async def get_authenticated_user_or_api_key_or_default(
cls, cls,
) -> Optional[Union[User, ApiKey]]: ) -> Optional[Union[User, ApiKey]]:
request = get_request()
if request is None:
return None
auth_header = request.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
if not auth_header: if not auth_header:
return None return None
return await cls._get_auth_type(auth_header) return await cls._get_auth_type(request, auth_header)
@classmethod @classmethod
async def is_authorized(cls) -> bool: async def is_authorized(cls) -> bool:
request = get_request()
if request is None:
return False
auth_header = request.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
if not auth_header: if not auth_header:
return False return False
@ -99,26 +113,25 @@ class Route(RouteUserExtension):
) )
@wraps(f) @wraps(f)
async def decorator(*args, **kwargs): async def decorator(request: Request, *args, **kwargs):
if skip_in_dev and Environment.get_environment() == "development": if skip_in_dev and Environment.get_environment() == "development":
if iscoroutinefunction(f): if iscoroutinefunction(f):
return await f(*args, **kwargs) return await f(request, *args, **kwargs)
return f(*args, **kwargs) return f(request, *args, **kwargs)
if not await cls.is_authorized(): if not await cls.is_authorized():
return unauthorized() return unauthorized()
if iscoroutinefunction(f): if iscoroutinefunction(f):
return await f(*args, **kwargs) return await f(request, *args, **kwargs)
return f(*args, **kwargs) return f(request, *args, **kwargs)
return decorator return decorator
@classmethod @classmethod
def route(cls, path=None, **kwargs): def route(cls, path=None, **kwargs):
def inner(fn): def inner(fn):
cross_origin(fn) cls.registered_routes.append(StarletteRoute(path, fn, **kwargs))
cls.registered_routes[path] = (fn, kwargs)
return fn return fn
return inner return inner

View File

@ -1,10 +1,11 @@
from typing import Optional from typing import Optional
from flask import request, Request, has_request_context
from keycloak import KeycloakAuthenticationError, KeycloakConnectionError from keycloak import KeycloakAuthenticationError, KeycloakConnectionError
from starlette.requests import Request
from api.auth.keycloak_client import Keycloak from api.auth.keycloak_client import Keycloak
from api.auth.keycloak_user import KeycloakUser from api.auth.keycloak_user import KeycloakUser
from api.middleware.request import get_request
from core.get_value import get_value from core.get_value import get_value
from core.logger import Logger from core.logger import Logger
from data.schemas.administration.user import User from data.schemas.administration.user import User
@ -19,8 +20,8 @@ logger = Logger(__name__)
class RouteUserExtension: class RouteUserExtension:
@classmethod @classmethod
def _get_user_id_from_token(cls) -> Optional[str]: def _get_user_id_from_token(cls, request: Request) -> Optional[str]:
token = cls.get_token() token = cls.get_token(request)
if not token: if not token:
return None return None
@ -34,7 +35,7 @@ class RouteUserExtension:
return get_value(user_info, "sub", str) return get_value(user_info, "sub", str)
@staticmethod @staticmethod
def get_token() -> Optional[str]: def get_token(request: Request) -> Optional[str]:
if "Authorization" not in request.headers: if "Authorization" not in request.headers:
return None return None
@ -45,10 +46,11 @@ class RouteUserExtension:
@classmethod @classmethod
async def get_user(cls) -> Optional[User]: async def get_user(cls) -> Optional[User]:
if not has_request_context(): request = get_request()
if request is None:
return None return None
user_id = cls._get_user_id_from_token() user_id = cls._get_user_id_from_token(request)
if not user_id: if not user_id:
return None return None
@ -56,8 +58,12 @@ class RouteUserExtension:
@classmethod @classmethod
async def get_dev_user(cls) -> Optional[User]: async def get_dev_user(cls) -> Optional[User]:
request = get_request()
if request is None:
return None
return await userDao.find_single_by( return await userDao.find_single_by(
[{User.keycloak_id: cls.get_token()}, {User.deleted: False}] [{User.keycloak_id: cls.get_token(request)}, {User.deleted: False}]
) )
@classmethod @classmethod
@ -86,8 +92,8 @@ class RouteUserExtension:
logger.error("Failed to find or create user", e) logger.error("Failed to find or create user", e)
@classmethod @classmethod
async def verify_login(cls, req: Request) -> bool: async def verify_login(cls, request: Request) -> bool:
auth_header = req.headers.get("Authorization", None) auth_header = request.headers.get("Authorization", None)
if not auth_header or not auth_header.startswith("Bearer "): if not auth_header or not auth_header.startswith("Bearer "):
return False return False

View File

@ -1,7 +1,8 @@
from uuid import uuid4 from uuid import uuid4
from flask import send_file from starlette.requests import Request
from werkzeug.exceptions import NotFound from starlette.responses import FileResponse
from starlette.exceptions import HTTPException
from api.route import Route from api.route import Route
from core.logger import APILogger from core.logger import APILogger
@ -9,19 +10,23 @@ from core.logger import APILogger
logger = APILogger(__name__) logger = APILogger(__name__)
@Route.get(f"/api/files/<path:file_path>") @Route.get("/api/files/{file_path:path}")
def get_file(file_path: str): async def get_file(request: Request):
file_path = request.path_params["file_path"]
name = file_path name = file_path
if "/" in file_path: if "/" in file_path:
name = file_path.split("/")[-1] name = file_path.split("/")[-1]
try: try:
return send_file( return FileResponse(
f"../files/{file_path}", download_name=name, as_attachment=True path=f"files/{file_path}",
filename=name,
media_type="application/octet-stream",
) )
except NotFound: except HTTPException as e:
return {"error": "File not found"}, 404 if e.status_code == 404:
except Exception as e: return {"error": "File not found"}, 404
error_id = uuid4() else:
logger.error(f"Error {error_id} getting file {file_path}", e) error_id = uuid4()
return {"error": f"File error. ErrorId: {error_id}"}, 500 logger.error(f"Error {error_id} getting file {file_path}", e)
return {"error": f"File error. ErrorId: {error_id}"}, 500

View File

@ -1,5 +1,6 @@
from ariadne import graphql from ariadne import graphql
from flask import request, jsonify from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route from api.route import Route
from api_graphql.service.schema import schema from api_graphql.service.schema import schema
@ -10,11 +11,11 @@ logger = Logger(__name__)
@Route.post(f"{BasePath}") @Route.post(f"{BasePath}")
async def graphql_endpoint(): async def graphql_endpoint(request: Request):
data = request.get_json() data = await request.json()
# Note: Passing the request to the context is optional. # Note: Passing the request to the context is optional.
# In Flask, the current request is always accessible as flask.request # In Starlette, the current request is accessible as request
success, result = await graphql(schema, data, context_value=request) success, result = await graphql(schema, data, context_value=request)
status_code = 200 status_code = 200
@ -24,4 +25,4 @@ async def graphql_endpoint():
] ]
status_code = max(status_codes, default=200) status_code = max(status_codes, default=200)
return jsonify(result), status_code return JSONResponse(result, status_code=status_code)

View File

@ -1,4 +1,6 @@
from ariadne.explorer import ExplorerPlayground from ariadne.explorer import ExplorerPlayground
from starlette.requests import Request
from starlette.responses import HTMLResponse
from api.route import Route from api.route import Route
from core.environment import Environment from core.environment import Environment
@ -10,7 +12,7 @@ logger = Logger(__name__)
@Route.get(f"{BasePath}/playground") @Route.get(f"{BasePath}/playground")
@Route.authorize(skip_in_dev=True) @Route.authorize(skip_in_dev=True)
async def playground(): async def playground(r: Request):
if Environment.get_environment() != "development": if Environment.get_environment() != "development":
return "", 403 return "", 403
@ -19,7 +21,6 @@ async def playground():
if dev_user: if dev_user:
request_global_headers = {f"Authorization": f"DEV-User {dev_user}"} request_global_headers = {f"Authorization": f"DEV-User {dev_user}"}
return ( return HTMLResponse(
ExplorerPlayground(request_global_headers=request_global_headers).html(None), ExplorerPlayground(request_global_headers=request_global_headers).html(None)
200,
) )

View File

@ -1,7 +1,16 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from api.route import Route from api.route import Route
from core.configuration.feature_flags import FeatureFlags
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from version import VERSION from version import VERSION
@Route.get(f"/api/version") @Route.get(f"/api/version")
def version(): async def version(r: Request):
return VERSION feature = await FeatureFlags.has_feature(FeatureFlagsEnum.version_endpoint)
if not feature:
return JSONResponse("DISABLED", status_code=403)
return JSONResponse(VERSION)

View File

@ -4,6 +4,7 @@ from api_graphql.abc.filter.bool_filter import BoolFilter
from api_graphql.abc.filter.int_filter import IntFilter from api_graphql.abc.filter.int_filter import IntFilter
from api_graphql.abc.filter.string_filter import StringFilter from api_graphql.abc.filter.string_filter import StringFilter
from api_graphql.abc.filter_abc import FilterABC from api_graphql.abc.filter_abc import FilterABC
from api_graphql.filter.fuzzy_filter import FuzzyFilter
class DbModelFilterABC[T](FilterABC[T]): class DbModelFilterABC[T](FilterABC[T]):
@ -18,3 +19,5 @@ class DbModelFilterABC[T](FilterABC[T]):
self.add_field("editor", IntFilter) self.add_field("editor", IntFilter)
self.add_field("createdUtc", StringFilter, "created") self.add_field("createdUtc", StringFilter, "created")
self.add_field("updatedUtc", StringFilter, "updated") self.add_field("updatedUtc", StringFilter, "updated")
self.add_field("fuzzy", FuzzyFilter)

View File

@ -16,7 +16,7 @@ class MutationABC(QueryABC):
self, self,
name: str, name: str,
mutation_name: str, mutation_name: str,
require_any_permission: list[Permissions] = None, require_any_permission=None,
public: bool = False, public: bool = False,
): ):
""" """
@ -27,6 +27,8 @@ class MutationABC(QueryABC):
:param bool public: Define if the field can resolve without authentication :param bool public: Define if the field can resolve without authentication
:return: :return:
""" """
if require_any_permission is None:
require_any_permission = []
from api_graphql.definition import QUERIES from api_graphql.definition import QUERIES
self.field( self.field(

View File

@ -4,12 +4,13 @@ from enum import Enum
from types import NoneType from types import NoneType
from typing import Callable, Type, get_args, Any, Union from typing import Callable, Type, get_args, Any, Union
from ariadne import ObjectType from ariadne import ObjectType, SubscriptionType
from graphql import GraphQLResolveInfo from graphql import GraphQLResolveInfo
from typing_extensions import deprecated from typing_extensions import deprecated
from api.route import Route from api.route import Route
from api_graphql.abc.collection_filter_abc import CollectionFilterABC from api_graphql.abc.collection_filter_abc import CollectionFilterABC
from api_graphql.abc.field_abc import FieldABC
from api_graphql.abc.input_abc import InputABC from api_graphql.abc.input_abc import InputABC
from api_graphql.abc.sort_abc import Sort from api_graphql.abc.sort_abc import Sort
from api_graphql.field.collection_field import CollectionField from api_graphql.field.collection_field import CollectionField
@ -20,6 +21,7 @@ from api_graphql.field.mutation_field import MutationField
from api_graphql.field.mutation_field_builder import MutationFieldBuilder from api_graphql.field.mutation_field_builder import MutationFieldBuilder
from api_graphql.field.resolver_field import ResolverField from api_graphql.field.resolver_field import ResolverField
from api_graphql.field.resolver_field_builder import ResolverFieldBuilder from api_graphql.field.resolver_field_builder import ResolverFieldBuilder
from api_graphql.field.subscription_field import SubscriptionField
from api_graphql.service.collection_result import CollectionResult from api_graphql.service.collection_result import CollectionResult
from api_graphql.service.exceptions import ( from api_graphql.service.exceptions import (
UnauthorizedException, UnauthorizedException,
@ -29,6 +31,7 @@ from api_graphql.service.exceptions import (
from api_graphql.service.query_context import QueryContext from api_graphql.service.query_context import QueryContext
from api_graphql.typing import TRequireAnyPermissions, TRequireAnyResolvers from api_graphql.typing import TRequireAnyPermissions, TRequireAnyResolvers
from core.logger import APILogger from core.logger import APILogger
from core.string import first_to_lower
from service.permission.permissions_enum import Permissions from service.permission.permissions_enum import Permissions
logger = APILogger(__name__) logger = APILogger(__name__)
@ -40,6 +43,7 @@ class QueryABC(ObjectType):
@abstractmethod @abstractmethod
def __init__(self, name: str = __name__): def __init__(self, name: str = __name__):
ObjectType.__init__(self, name) ObjectType.__init__(self, name)
self._subscriptions: dict[str, SubscriptionType] = {}
@staticmethod @staticmethod
async def _authorize(): async def _authorize():
@ -67,6 +71,8 @@ class QueryABC(ObjectType):
*args, *args,
**kwargs, **kwargs,
): ):
info = args[0]
if len(permissions) > 0: if len(permissions) > 0:
user = await Route.get_authenticated_user_or_api_key_or_default() user = await Route.get_authenticated_user_or_api_key_or_default()
if user is not None and all( if user is not None and all(
@ -132,7 +138,12 @@ class QueryABC(ObjectType):
skip = kwargs["skip"] skip = kwargs["skip"]
collection = await field.dao.find_by(filters, sorts, take, skip) collection = await field.dao.find_by(filters, sorts, take, skip)
res = CollectionResult(await field.dao.count(), len(collection), collection) if field.direct_result:
return collection
res = CollectionResult(
await field.dao.count(filters), len(collection), collection
)
return res return res
async def collection_wrapper(*args, **kwargs): async def collection_wrapper(*args, **kwargs):
@ -169,11 +180,12 @@ class QueryABC(ObjectType):
) )
async def resolver_wrapper(*args, **kwargs): async def resolver_wrapper(*args, **kwargs):
return ( result = (
await field.resolver(*args, **kwargs) await field.resolver(*args, **kwargs)
if iscoroutinefunction(field.resolver) if iscoroutinefunction(field.resolver)
else field.resolver(*args, **kwargs) else field.resolver(*args, **kwargs)
) )
return result
if isinstance(field, DaoField): if isinstance(field, DaoField):
resolver = dao_wrapper resolver = dao_wrapper
@ -203,6 +215,13 @@ class QueryABC(ObjectType):
resolver = input_wrapper resolver = input_wrapper
elif isinstance(field, SubscriptionField):
async def sub_wrapper(sub: QueryABC, info: GraphQLResolveInfo, **kwargs):
return await resolver_wrapper(sub, info, **kwargs)
resolver = sub_wrapper
else: else:
raise ValueError(f"Unknown field type: {field.name}") raise ValueError(f"Unknown field type: {field.name}")
@ -220,7 +239,12 @@ class QueryABC(ObjectType):
result = await resolver(*args, **kwargs) result = await resolver(*args, **kwargs)
if field.require_any is not None: if field.require_any is not None:
await self._require_any(result, *field.require_any, *args, **kwargs) await self._require_any(
result,
*field.require_any,
*args,
**kwargs,
)
return result return result
@ -250,6 +274,9 @@ class QueryABC(ObjectType):
self.field( self.field(
MutationFieldBuilder(name) MutationFieldBuilder(name)
.with_resolver(f) .with_resolver(f)
.with_change_broadcast(
f"{first_to_lower(self.name.replace("Mutation", ""))}Change"
)
.with_input(input_type, input_key) .with_input(input_type, input_key)
.with_require_any_permission(require_any_permission) .with_require_any_permission(require_any_permission)
.with_public(public) .with_public(public)
@ -271,6 +298,8 @@ class QueryABC(ObjectType):
for f in filters: for f in filters:
collection = list(filter(lambda x: f.filter(x), collection)) collection = list(filter(lambda x: f.filter(x), collection))
total_count = len(collection)
if sort is not None: if sort is not None:
def f_sort(x: object, k: str): def f_sort(x: object, k: str):

View File

@ -0,0 +1,51 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from ariadne import SubscriptionType
from api.middleware.request import get_request
from api_graphql.abc.query_abc import QueryABC
from api_graphql.field.subscription_field_builder import SubscriptionFieldBuilder
from core.logger import APILogger
logger = APILogger(__name__)
class SubscriptionABC(SubscriptionType, QueryABC):
@abstractmethod
def __init__(self):
SubscriptionType.__init__(self)
def subscribe(self, builder: SubscriptionFieldBuilder):
field = builder.build()
async def wrapper(*args, **kwargs):
if not field.public:
await self._authorize()
if (
field.require_any is None
and not field.public
and field.require_any_permission
):
await self._require_any_permission(field.require_any_permission)
result = (
await field.resolver(*args, **kwargs)
if iscoroutinefunction(field.resolver)
else field.resolver(*args, **kwargs)
)
if field.require_any is not None:
await self._require_any(
result,
*field.require_any,
*args,
**kwargs,
)
return result
self.set_field(field.name, wrapper)
self.set_source(field.name, field.generator)

View File

@ -4,6 +4,7 @@ import os
from api_graphql.abc.db_model_query_abc import DbModelQueryABC from api_graphql.abc.db_model_query_abc import DbModelQueryABC
from api_graphql.abc.mutation_abc import MutationABC from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.abc.query_abc import QueryABC from api_graphql.abc.query_abc import QueryABC
from api_graphql.abc.subscription_abc import SubscriptionABC
from api_graphql.query import Query from api_graphql.query import Query
@ -19,7 +20,7 @@ def import_graphql_schema_part(part: str):
import_graphql_schema_part("queries") import_graphql_schema_part("queries")
import_graphql_schema_part("mutations") import_graphql_schema_part("mutations")
sub_query_classes = [DbModelQueryABC, MutationABC] sub_query_classes = [DbModelQueryABC, MutationABC, SubscriptionABC]
query_classes = [ query_classes = [
*[y for x in sub_query_classes for y in x.__subclasses__()], *[y for x in sub_query_classes for y in x.__subclasses__()],
*[x for x in QueryABC.__subclasses__() if x not in sub_query_classes], *[x for x in QueryABC.__subclasses__() if x not in sub_query_classes],

View File

@ -20,6 +20,7 @@ class DaoField(FieldABC):
dao: DataAccessObjectABC = None, dao: DataAccessObjectABC = None,
filter_type: Type[FilterABC] = None, filter_type: Type[FilterABC] = None,
sort_type: Type[T] = None, sort_type: Type[T] = None,
direct_result: bool = False,
): ):
FieldABC.__init__(self, name, require_any_permission, require_any, public) FieldABC.__init__(self, name, require_any_permission, require_any, public)
self._name = name self._name = name
@ -28,6 +29,7 @@ class DaoField(FieldABC):
self._dao = dao self._dao = dao
self._filter_type = filter_type self._filter_type = filter_type
self._sort_type = sort_type self._sort_type = sort_type
self._direct_result = direct_result
@property @property
def dao(self) -> Optional[DataAccessObjectABC]: def dao(self) -> Optional[DataAccessObjectABC]:
@ -42,3 +44,7 @@ class DaoField(FieldABC):
@property @property
def sort_type(self) -> Optional[Type[T]]: def sort_type(self) -> Optional[Type[T]]:
return self._sort_type return self._sort_type
@property
def direct_result(self) -> bool:
return self._direct_result

View File

@ -15,6 +15,7 @@ class DaoFieldBuilder(FieldBuilderABC):
self._dao = None self._dao = None
self._filter_type = None self._filter_type = None
self._sort_type = None self._sort_type = None
self._direct_result = False
def with_dao(self, dao: DataAccessObjectABC) -> Self: def with_dao(self, dao: DataAccessObjectABC) -> Self:
assert dao is not None, "dao cannot be None" assert dao is not None, "dao cannot be None"
@ -31,6 +32,10 @@ class DaoFieldBuilder(FieldBuilderABC):
self._sort_type = sort_type self._sort_type = sort_type
return self return self
def with_direct_result(self) -> Self:
self._direct_result = True
return self
def build(self) -> DaoField: def build(self) -> DaoField:
assert self._dao is not None, "dao cannot be None" assert self._dao is not None, "dao cannot be None"
return DaoField( return DaoField(
@ -41,4 +46,5 @@ class DaoFieldBuilder(FieldBuilderABC):
self._dao, self._dao,
self._filter_type, self._filter_type,
self._sort_type, self._sort_type,
self._direct_result,
) )

View File

@ -1,7 +1,9 @@
from asyncio import iscoroutinefunction
from typing import Self, Type from typing import Self, Type
from ariadne.types import Resolver from ariadne.types import Resolver
from api.broadcast import broadcast
from api_graphql.abc.field_builder_abc import FieldBuilderABC from api_graphql.abc.field_builder_abc import FieldBuilderABC
from api_graphql.abc.input_abc import InputABC from api_graphql.abc.input_abc import InputABC
from api_graphql.field.mutation_field import MutationField from api_graphql.field.mutation_field import MutationField
@ -18,9 +20,41 @@ class MutationFieldBuilder(FieldBuilderABC):
def with_resolver(self, resolver: Resolver) -> Self: def with_resolver(self, resolver: Resolver) -> Self:
assert resolver is not None, "resolver cannot be None" assert resolver is not None, "resolver cannot be None"
self._resolver = resolver self._resolver = resolver
return self return self
def with_broadcast(self, source: str):
assert self._resolver is not None, "resolver cannot be None for broadcast"
resolver = self._resolver
async def resolver_wrapper(*args, **kwargs):
result = (
await resolver(*args, **kwargs)
if iscoroutinefunction(resolver)
else resolver(*args, **kwargs)
)
await broadcast.publish(f"{source}", result)
return result
def with_change_broadcast(self, source: str):
assert self._resolver is not None, "resolver cannot be None for broadcast"
resolver = self._resolver
async def resolver_wrapper(*args, **kwargs):
result = (
await resolver(*args, **kwargs)
if iscoroutinefunction(resolver)
else resolver(*args, **kwargs)
)
await broadcast.publish(f"{source}", {})
return result
self._resolver = resolver_wrapper
return self
def with_input(self, input_type: Type[InputABC], input_key: str = None) -> Self: def with_input(self, input_type: Type[InputABC], input_key: str = None) -> Self:
self._input_type = input_type self._input_type = input_type
self._input_key = input_key self._input_key = input_key

View File

@ -16,11 +16,17 @@ class ResolverField(FieldABC):
require_any: TRequireAny = None, require_any: TRequireAny = None,
public: bool = False, public: bool = False,
resolver: Resolver = None, resolver: Resolver = None,
direct_result: bool = False,
): ):
FieldABC.__init__(self, name, require_any_permission, require_any, public) FieldABC.__init__(self, name, require_any_permission, require_any, public)
self._resolver = resolver self._resolver = resolver
self._direct_result = direct_result
@property @property
def resolver(self) -> Optional[Resolver]: def resolver(self) -> Optional[Resolver]:
return self._resolver return self._resolver
@property
def direct_result(self) -> bool:
return self._direct_result

View File

@ -12,12 +12,17 @@ class ResolverFieldBuilder(FieldBuilderABC):
FieldBuilderABC.__init__(self, name) FieldBuilderABC.__init__(self, name)
self._resolver = None self._resolver = None
self._direct_result = False
def with_resolver(self, resolver: Resolver) -> Self: def with_resolver(self, resolver: Resolver) -> Self:
assert resolver is not None, "resolver cannot be None" assert resolver is not None, "resolver cannot be None"
self._resolver = resolver self._resolver = resolver
return self return self
def with_direct_result(self) -> Self:
self._direct_result = True
return self
def build(self) -> ResolverField: def build(self) -> ResolverField:
assert self._resolver is not None, "resolver cannot be None" assert self._resolver is not None, "resolver cannot be None"
return ResolverField( return ResolverField(
@ -26,4 +31,5 @@ class ResolverFieldBuilder(FieldBuilderABC):
self._require_any, self._require_any,
self._public, self._public,
self._resolver, self._resolver,
self._direct_result,
) )

View File

@ -0,0 +1,32 @@
from typing import Optional
from ariadne.types import Resolver
from api_graphql.abc.field_abc import FieldABC
from api_graphql.typing import TRequireAny
from service.permission.permissions_enum import Permissions
class SubscriptionField(FieldABC):
def __init__(
self,
name: str,
require_any_permission: list[Permissions] = None,
require_any: TRequireAny = None,
public: bool = False,
resolver: Resolver = None,
generator: Resolver = None,
):
FieldABC.__init__(self, name, require_any_permission, require_any, public)
self._resolver = resolver
self._generator = generator
@property
def resolver(self) -> Optional[Resolver]:
return self._resolver
@property
def generator(self) -> Optional[Resolver]:
return self._generator

View File

@ -0,0 +1,46 @@
from typing import Self, AsyncGenerator
from ariadne.types import Resolver
from api.broadcast import broadcast
from api_graphql.abc.field_builder_abc import FieldBuilderABC
from api_graphql.field.subscription_field import SubscriptionField
class SubscriptionFieldBuilder(FieldBuilderABC):
def __init__(self, name: str):
FieldBuilderABC.__init__(self, name)
self._resolver = None
self._generator = None
def with_resolver(self, resolver: Resolver) -> Self:
assert resolver is not None, "resolver cannot be None"
self._resolver = resolver
return self
def with_generator(self, generator: Resolver) -> Self:
assert generator is not None, "generator cannot be None"
self._generator = generator
return self
def build(self) -> SubscriptionField:
assert self._resolver is not None, "resolver cannot be None"
if self._generator is None:
async def generator(*args, **kwargs) -> AsyncGenerator[str, None]:
async with broadcast.subscribe(channel=self._name) as subscriber:
async for message in subscriber:
yield message
self._generator = generator
return SubscriptionField(
self._name,
self._require_any_permission,
self._require_any,
self._public,
self._resolver,
self._generator,
)

View File

@ -0,0 +1,13 @@
from api_graphql.abc.db_model_filter_abc import DbModelFilterABC
from api_graphql.abc.filter.string_filter import StringFilter
class DomainFilter(DbModelFilterABC):
def __init__(
self,
obj: dict,
):
DbModelFilterABC.__init__(self, obj)
self.add_field("name", StringFilter)
self.add_field("description", StringFilter)

View File

@ -0,0 +1,15 @@
from typing import Optional
from api_graphql.abc.filter_abc import FilterABC
class FuzzyFilter(FilterABC):
def __init__(
self,
obj: Optional[dict],
):
FilterABC.__init__(self, obj)
self.add_field("fields", list)
self.add_field("term", str)
self.add_field("threshold", int)

View File

@ -9,6 +9,6 @@ class ShortUrlFilter(DbModelFilterABC):
): ):
DbModelFilterABC.__init__(self, obj) DbModelFilterABC.__init__(self, obj)
self.add_field("short_url", StringFilter) self.add_field("shortUrl", StringFilter, db_name="short_url")
self.add_field("target_url", StringFilter) self.add_field("targetUrl", StringFilter, db_name="target_url")
self.add_field("description", StringFilter) self.add_field("description", StringFilter)

View File

@ -21,11 +21,21 @@ input ApiKeySort {
identifier: SortOrder identifier: SortOrder
deleted: SortOrder deleted: SortOrder
editorId: SortOrder editor: UserSort
createdUtc: SortOrder createdUtc: SortOrder
updatedUtc: SortOrder updatedUtc: SortOrder
} }
enum ApiKeyFuzzyFields {
identifier
}
input ApiKeyFuzzy {
fields: [ApiKeyFuzzyFields]
term: String
threshold: Int
}
input ApiKeyFilter { input ApiKeyFilter {
id: IntFilter id: IntFilter
identifier: StringFilter identifier: StringFilter

View File

@ -0,0 +1,65 @@
type DomainResult {
totalCount: Int
count: Int
nodes: [Domain]
}
type Domain implements DbModel {
id: ID
name: String
shortUrls: [ShortUrl]
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
input DomainSort {
id: SortOrder
name: SortOrder
deleted: SortOrder
editorId: SortOrder
createdUtc: SortOrder
updatedUtc: SortOrder
}
enum DomainFuzzyFields {
name
}
input DomainFuzzy {
fields: [DomainFuzzyFields]
term: String
threshold: Int
}
input DomainFilter {
id: IntFilter
name: StringFilter
fuzzy: DomainFuzzy
deleted: BooleanFilter
editor: IntFilter
createdUtc: DateFilter
updatedUtc: DateFilter
}
type DomainMutation {
create(input: DomainCreateInput!): Domain
update(input: DomainUpdateInput!): Domain
delete(id: ID!): Boolean
restore(id: ID!): Boolean
}
input DomainCreateInput {
name: String!
}
input DomainUpdateInput {
id: ID!
name: String
}

View File

@ -0,0 +1,19 @@
type FeatureFlag implements DbModel {
id: ID
key: String
value: Boolean
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
type FeatureFlagMutation {
change(input: FeatureFlagInput!): FeatureFlag
}
input FeatureFlagInput {
key: String!
value: Boolean!
}

View File

@ -9,6 +9,7 @@ type Group implements DbModel {
name: String name: String
shortUrls: [ShortUrl] shortUrls: [ShortUrl]
roles: [Role]
deleted: Boolean deleted: Boolean
editor: User editor: User
@ -26,10 +27,22 @@ input GroupSort {
updatedUtc: SortOrder updatedUtc: SortOrder
} }
enum GroupFuzzyFields {
name
}
input GroupFuzzy {
fields: [GroupFuzzyFields]
term: String
threshold: Int
}
input GroupFilter { input GroupFilter {
id: IntFilter id: IntFilter
name: StringFilter name: StringFilter
fuzzy: GroupFuzzy
deleted: BooleanFilter deleted: BooleanFilter
editor: IntFilter editor: IntFilter
createdUtc: DateFilter createdUtc: DateFilter
@ -45,9 +58,11 @@ type GroupMutation {
input GroupCreateInput { input GroupCreateInput {
name: String! name: String!
roles: [ID]
} }
input GroupUpdateInput { input GroupUpdateInput {
id: ID! id: ID!
name: String name: String
roles: [ID]
} }

View File

@ -5,5 +5,10 @@ type Mutation {
role: RoleMutation role: RoleMutation
group: GroupMutation group: GroupMutation
domain: DomainMutation
shortUrl: ShortUrlMutation shortUrl: ShortUrlMutation
setting: SettingMutation
userSetting: UserSettingMutation
featureFlag: FeatureFlagMutation
} }

View File

@ -11,6 +11,11 @@ type Query {
userHasAnyPermission(permissions: [String]!): Boolean userHasAnyPermission(permissions: [String]!): Boolean
notExistingUsersFromKeycloak: KeycloakUserResult notExistingUsersFromKeycloak: KeycloakUserResult
domains(filter: [DomainFilter], sort: [DomainSort], skip: Int, take: Int): DomainResult
groups(filter: [GroupFilter], sort: [GroupSort], skip: Int, take: Int): GroupResult groups(filter: [GroupFilter], sort: [GroupSort], skip: Int, take: Int): GroupResult
shortUrls(filter: [ShortUrlFilter], sort: [ShortUrlSort], skip: Int, take: Int): ShortUrlResult shortUrls(filter: [ShortUrlFilter], sort: [ShortUrlSort], skip: Int, take: Int): ShortUrlResult
settings(key: String): [Setting]
userSettings(key: String): [Setting]
featureFlags(key: String): [FeatureFlag]
} }

View File

@ -23,18 +23,31 @@ input RoleSort {
description: SortOrder description: SortOrder
deleted: SortOrder deleted: SortOrder
editorId: SortOrder editor: UserSort
createdUtc: SortOrder createdUtc: SortOrder
updatedUtc: SortOrder updatedUtc: SortOrder
} }
enum RoleFuzzyFields {
name
description
}
input RoleFuzzy {
fields: [RoleFuzzyFields]
term: String
threshold: Int
}
input RoleFilter { input RoleFilter {
id: IntFilter id: IntFilter
name: StringFilter name: StringFilter
description: StringFilter description: StringFilter
fuzzy: RoleFuzzy
deleted: BooleanFilter deleted: BooleanFilter
editorId: IntFilter editor_id: IntFilter
createdUtc: DateFilter createdUtc: DateFilter
updatedUtc: DateFilter updatedUtc: DateFilter
} }

View File

@ -0,0 +1,19 @@
type Setting implements DbModel {
id: ID
key: String
value: String
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
type SettingMutation {
change(input: SettingInput!): Setting
}
input SettingInput {
key: String!
value: String!
}

View File

@ -11,6 +11,7 @@ type ShortUrl implements DbModel {
description: String description: String
visits: Int visits: Int
group: Group group: Group
domain: Domain
loadingScreen: Boolean loadingScreen: Boolean
deleted: Boolean deleted: Boolean
@ -31,12 +32,27 @@ input ShortUrlSort {
updatedUtc: SortOrder updatedUtc: SortOrder
} }
enum ShortUrlFuzzyFields {
shortUrl
targetUrl
description
}
input ShortUrlFuzzy {
fields: [ShortUrlFuzzyFields]
term: String
threshold: Int
}
input ShortUrlFilter { input ShortUrlFilter {
id: IntFilter id: IntFilter
name: StringFilter shortUrl: StringFilter
targetUrl: StringFilter
description: StringFilter description: StringFilter
loadingScreen: BooleanFilter loadingScreen: BooleanFilter
fuzzy: ShortUrlFuzzy
deleted: BooleanFilter deleted: BooleanFilter
editor: IntFilter editor: IntFilter
createdUtc: DateFilter createdUtc: DateFilter
@ -48,6 +64,7 @@ type ShortUrlMutation {
update(input: ShortUrlUpdateInput!): ShortUrl update(input: ShortUrlUpdateInput!): ShortUrl
delete(id: ID!): Boolean delete(id: ID!): Boolean
restore(id: ID!): Boolean restore(id: ID!): Boolean
trackVisit(id: ID!, agent: String): Boolean
} }
input ShortUrlCreateInput { input ShortUrlCreateInput {
@ -55,6 +72,7 @@ input ShortUrlCreateInput {
targetUrl: String! targetUrl: String!
description: String description: String
groupId: ID groupId: ID
domainId: ID
loadingScreen: Boolean loadingScreen: Boolean
} }
@ -64,5 +82,6 @@ input ShortUrlUpdateInput {
targetUrl: String targetUrl: String
description: String description: String
groupId: ID groupId: ID
domainId: ID
loadingScreen: Boolean loadingScreen: Boolean
} }

View File

@ -0,0 +1,16 @@
scalar SubscriptionChange
type Subscription {
ping: String
apiKeyChange: SubscriptionChange
featureFlagChange: SubscriptionChange
roleChange: SubscriptionChange
settingChange: SubscriptionChange
userChange: SubscriptionChange
userSettingChange: SubscriptionChange
domainChange: SubscriptionChange
groupChange: SubscriptionChange
shortUrlChange: SubscriptionChange
}

View File

@ -35,19 +35,33 @@ input UserSort {
email: SortOrder email: SortOrder
deleted: SortOrder deleted: SortOrder
editorId: SortOrder editor: UserSort
createdUtc: SortOrder createdUtc: SortOrder
updatedUtc: SortOrder updatedUtc: SortOrder
} }
enum UserFuzzyFields {
keycloakId
username
email
}
input UserFuzzy {
fields: [UserFuzzyFields]
term: String
threshold: Int
}
input UserFilter { input UserFilter {
id: IntFilter id: IntFilter
keycloakId: StringFilter keycloakId: StringFilter
username: StringFilter username: StringFilter
email: StringFilter email: StringFilter
fuzzy: UserFuzzy
deleted: BooleanFilter deleted: BooleanFilter
editor: IntFilter editor: UserFilter
createdUtc: DateFilter createdUtc: DateFilter
updatedUtc: DateFilter updatedUtc: DateFilter
} }

View File

@ -0,0 +1,19 @@
type UserSetting implements DbModel {
id: ID
key: String
value: String
deleted: Boolean
editor: User
createdUtc: String
updatedUtc: String
}
type UserSettingMutation {
change(input: UserSettingInput!): UserSetting
}
input UserSettingInput {
key: String!
value: String!
}

View File

@ -0,0 +1,13 @@
from api_graphql.abc.input_abc import InputABC
class DomainCreateInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._name = self.option("name", str, required=True)
@property
def name(self) -> str:
return self._name

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class DomainUpdateInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._id = self.option("id", int, required=True)
self._name = self.option("name", str)
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._name

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class FeatureFlagInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._key = self.option("key", str, required=True)
self._value = self.option("value", bool, required=True)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> bool:
return self._value

View File

@ -7,7 +7,12 @@ class GroupCreateInput(InputABC):
InputABC.__init__(self, src) InputABC.__init__(self, src)
self._name = self.option("name", str, required=True) self._name = self.option("name", str, required=True)
self._roles = self.option("roles", list[int])
@property @property
def name(self) -> str: def name(self) -> str:
return self._name return self._name
@property
def roles(self) -> list[int]:
return self._roles

View File

@ -8,6 +8,7 @@ class GroupUpdateInput(InputABC):
self._id = self.option("id", int, required=True) self._id = self.option("id", int, required=True)
self._name = self.option("name", str) self._name = self.option("name", str)
self._roles = self.option("roles", list[int])
@property @property
def id(self) -> int: def id(self) -> int:
@ -16,3 +17,7 @@ class GroupUpdateInput(InputABC):
@property @property
def name(self) -> str: def name(self) -> str:
return self._name return self._name
@property
def roles(self) -> list[int]:
return self._roles

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class SettingInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._key = self.option("key", str, required=True)
self._value = self.option("value", str, required=True)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value

View File

@ -12,6 +12,7 @@ class ShortUrlCreateInput(InputABC):
self._target_url = self.option("targetUrl", str, required=True) self._target_url = self.option("targetUrl", str, required=True)
self._description = self.option("description", str) self._description = self.option("description", str)
self._group_id = self.option("groupId", int) self._group_id = self.option("groupId", int)
self._domain_id = self.option("domainId", int)
self._loading_screen = self.option("loadingScreen", bool) self._loading_screen = self.option("loadingScreen", bool)
@property @property
@ -30,6 +31,10 @@ class ShortUrlCreateInput(InputABC):
def group_id(self) -> Optional[int]: def group_id(self) -> Optional[int]:
return self._group_id return self._group_id
@property
def domain_id(self) -> Optional[int]:
return self._domain_id
@property @property
def loading_screen(self) -> Optional[str]: def loading_screen(self) -> Optional[str]:
return self._loading_screen return self._loading_screen

View File

@ -13,6 +13,7 @@ class ShortUrlUpdateInput(InputABC):
self._target_url = self.option("targetUrl", str) self._target_url = self.option("targetUrl", str)
self._description = self.option("description", str) self._description = self.option("description", str)
self._group_id = self.option("groupId", int) self._group_id = self.option("groupId", int)
self._domain_id = self.option("domainId", int)
self._loading_screen = self.option("loadingScreen", bool) self._loading_screen = self.option("loadingScreen", bool)
@property @property
@ -35,6 +36,10 @@ class ShortUrlUpdateInput(InputABC):
def group_id(self) -> Optional[int]: def group_id(self) -> Optional[int]:
return self._group_id return self._group_id
@property
def domain_id(self) -> Optional[int]:
return self._domain_id
@property @property
def loading_screen(self) -> Optional[str]: def loading_screen(self) -> Optional[str]:
return self._loading_screen return self._loading_screen

View File

@ -0,0 +1,18 @@
from api_graphql.abc.input_abc import InputABC
class UserSettingInput(InputABC):
def __init__(self, src: dict):
InputABC.__init__(self, src)
self._key = self.option("key", str, required=True)
self._value = self.option("value", str, required=True)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value

View File

@ -33,6 +33,15 @@ class Mutation(MutationABC):
], ],
) )
self.add_mutation_type(
"domain",
"Domain",
require_any_permission=[
Permissions.domains_create,
Permissions.domains_update,
Permissions.domains_delete,
],
)
self.add_mutation_type( self.add_mutation_type(
"group", "group",
"Group", "Group",
@ -51,3 +60,22 @@ class Mutation(MutationABC):
Permissions.short_urls_delete, Permissions.short_urls_delete,
], ],
) )
self.add_mutation_type(
"setting",
"Setting",
require_any_permission=[
Permissions.settings_update,
],
)
self.add_mutation_type(
"userSetting",
"UserSetting",
)
self.add_mutation_type(
"featureFlag",
"FeatureFlag",
require_any_permission=[
Permissions.administrator,
],
)

View File

@ -0,0 +1,75 @@
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.domain_create_input import DomainCreateInput
from api_graphql.input.domain_update_input import DomainUpdateInput
from api_graphql.input.group_create_input import GroupCreateInput
from api_graphql.input.group_update_input import GroupUpdateInput
from core.logger import APILogger
from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group import Group
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class DomainMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "Domain")
self.mutation(
"create",
self.resolve_create,
DomainCreateInput,
require_any_permission=[Permissions.domains_create],
)
self.mutation(
"update",
self.resolve_update,
DomainUpdateInput,
require_any_permission=[Permissions.domains_update],
)
self.mutation(
"delete",
self.resolve_delete,
require_any_permission=[Permissions.domains_delete],
)
self.mutation(
"restore",
self.resolve_restore,
require_any_permission=[Permissions.domains_delete],
)
@staticmethod
async def resolve_create(obj: GroupCreateInput, *_):
logger.debug(f"create domain: {obj.__dict__}")
domain = Group(
0,
obj.name,
)
nid = await domainDao.create(domain)
return await domainDao.get_by_id(nid)
@staticmethod
async def resolve_update(obj: GroupUpdateInput, *_):
logger.debug(f"update domain: {input}")
if obj.name is not None:
domain = await domainDao.get_by_id(obj.id)
domain.name = obj.name
await domainDao.update(domain)
return await domainDao.get_by_id(obj.id)
@staticmethod
async def resolve_delete(*_, id: str):
logger.debug(f"delete domain: {id}")
domain = await domainDao.get_by_id(id)
await domainDao.delete(domain)
return True
@staticmethod
async def resolve_restore(*_, id: str):
logger.debug(f"restore domain: {id}")
domain = await domainDao.get_by_id(id)
await domainDao.restore(domain)
return True

View File

@ -0,0 +1,32 @@
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.feature_flag_input import FeatureFlagInput
from core.logger import APILogger
from data.schemas.system.feature_flag import FeatureFlag
from data.schemas.system.feature_flag_dao import featureFlagDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class FeatureFlagMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "FeatureFlag")
self.mutation(
"change",
self.resolve_change,
FeatureFlagInput,
require_any_permission=[Permissions.administrator],
)
@staticmethod
async def resolve_change(obj: FeatureFlagInput, *_):
logger.debug(f"create new feature flag: {input}")
setting = await featureFlagDao.find_single_by({FeatureFlag.key: obj.key})
if setting is None:
raise ValueError(f"FeatureFlag {obj.key} not found")
setting.value = obj.value
await featureFlagDao.update(setting)
return await featureFlagDao.get_by_id(setting.id)

View File

@ -1,9 +1,13 @@
from typing import Optional
from api_graphql.abc.mutation_abc import MutationABC from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.group_create_input import GroupCreateInput from api_graphql.input.group_create_input import GroupCreateInput
from api_graphql.input.group_update_input import GroupUpdateInput from api_graphql.input.group_update_input import GroupUpdateInput
from core.logger import APILogger from core.logger import APILogger
from data.schemas.public.group import Group from data.schemas.public.group import Group
from data.schemas.public.group_dao import groupDao from data.schemas.public.group_dao import groupDao
from data.schemas.public.group_role_assignment import GroupRoleAssignment
from data.schemas.public.group_role_assignment_dao import groupRoleAssignmentDao
from service.permission.permissions_enum import Permissions from service.permission.permissions_enum import Permissions
logger = APILogger(__name__) logger = APILogger(__name__)
@ -37,25 +41,61 @@ class GroupMutation(MutationABC):
) )
@staticmethod @staticmethod
async def resolve_create(obj: GroupCreateInput, *_): async def _handle_group_role_assignments(gid: int, roles: Optional[list[int]]):
if roles is None:
return
existing_roles = await groupDao.get_roles(gid)
existing_role_ids = {role.id for role in existing_roles}
new_role_ids = set(roles)
roles_to_add = new_role_ids - existing_role_ids
roles_to_remove = existing_role_ids - new_role_ids
if roles_to_add:
group_role_assignments = [
GroupRoleAssignment(0, gid, role_id) for role_id in roles_to_add
]
await groupRoleAssignmentDao.create_many(group_role_assignments)
if roles_to_remove:
assignments_to_remove = await groupRoleAssignmentDao.find_by(
[
{GroupRoleAssignment.group_id: gid},
{GroupRoleAssignment.role_id: {"in": roles_to_remove}},
]
)
await groupRoleAssignmentDao.delete_many(assignments_to_remove)
@classmethod
async def resolve_create(cls, obj: GroupCreateInput, *_):
logger.debug(f"create group: {obj.__dict__}") logger.debug(f"create group: {obj.__dict__}")
group = Group( group = Group(
0, 0,
obj.name, obj.name,
) )
nid = await groupDao.create(group) gid = await groupDao.create(group)
return await groupDao.get_by_id(nid)
@staticmethod await cls._handle_group_role_assignments(gid, obj.roles)
async def resolve_update(obj: GroupUpdateInput, *_):
return await groupDao.get_by_id(gid)
@classmethod
async def resolve_update(cls, obj: GroupUpdateInput, *_):
logger.debug(f"update group: {input}") logger.debug(f"update group: {input}")
if await groupDao.find_by_id(obj.id) is None:
raise ValueError(f"Group with id {obj.id} not found")
if obj.name is not None: if obj.name is not None:
group = await groupDao.get_by_id(obj.id) group = await groupDao.get_by_id(obj.id)
group.name = obj.name group.name = obj.name
await groupDao.update(group) await groupDao.update(group)
await cls._handle_group_role_assignments(obj.id, obj.roles)
return await groupDao.get_by_id(obj.id) return await groupDao.get_by_id(obj.id)
@staticmethod @staticmethod

View File

@ -0,0 +1,32 @@
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.setting_input import SettingInput
from core.logger import APILogger
from data.schemas.system.setting import Setting
from data.schemas.system.setting_dao import settingsDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class SettingMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "Setting")
self.mutation(
"change",
self.resolve_change,
SettingInput,
require_any_permission=[Permissions.settings_update],
)
@staticmethod
async def resolve_change(obj: SettingInput, *_):
logger.debug(f"create new setting: {input}")
setting = await settingsDao.find_single_by({Setting.key: obj.key})
if setting is None:
raise ValueError(f"Setting with key {obj.key} not found")
setting.value = obj.value
await settingsDao.update(setting)
return await settingsDao.get_by_id(setting.id)

View File

@ -1,12 +1,13 @@
from werkzeug.exceptions import NotFound
from api_graphql.abc.mutation_abc import MutationABC from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.short_url_create_input import ShortUrlCreateInput from api_graphql.input.short_url_create_input import ShortUrlCreateInput
from api_graphql.input.short_url_update_input import ShortUrlUpdateInput from api_graphql.input.short_url_update_input import ShortUrlUpdateInput
from core.logger import APILogger from core.logger import APILogger
from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group_dao import groupDao from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.short_url_visit import ShortUrlVisit
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao
from service.permission.permissions_enum import Permissions from service.permission.permissions_enum import Permissions
logger = APILogger(__name__) logger = APILogger(__name__)
@ -38,6 +39,11 @@ class ShortUrlMutation(MutationABC):
self.resolve_restore, self.resolve_restore,
require_any_permission=[Permissions.short_urls_delete], require_any_permission=[Permissions.short_urls_delete],
) )
self.mutation(
"trackVisit",
self.resolve_track_visit,
require_any_permission=[Permissions.short_urls_update],
)
@staticmethod @staticmethod
async def resolve_create(obj: ShortUrlCreateInput, *_): async def resolve_create(obj: ShortUrlCreateInput, *_):
@ -49,6 +55,7 @@ class ShortUrlMutation(MutationABC):
obj.target_url, obj.target_url,
obj.description, obj.description,
obj.group_id, obj.group_id,
obj.domain_id,
obj.loading_screen, obj.loading_screen,
) )
nid = await shortUrlDao.create(short_url) nid = await shortUrlDao.create(short_url)
@ -72,8 +79,18 @@ class ShortUrlMutation(MutationABC):
if obj.group_id is not None: if obj.group_id is not None:
group_by_id = await groupDao.find_by_id(obj.group_id) group_by_id = await groupDao.find_by_id(obj.group_id)
if group_by_id is None: if group_by_id is None:
raise NotFound(f"Group with id {obj.group_id} does not exist") raise ValueError(f"Group with id {obj.group_id} does not exist")
short_url.group_id = obj.group_id short_url.group_id = obj.group_id
else:
short_url.group_id = None
if obj.domain_id is not None:
domain_by_id = await domainDao.find_by_id(obj.domain_id)
if domain_by_id is None:
raise ValueError(f"Domain with id {obj.domain_id} does not exist")
short_url.domain_id = obj.domain_id
else:
short_url.domain_id = None
if obj.loading_screen is not None: if obj.loading_screen is not None:
short_url.loading_screen = obj.loading_screen short_url.loading_screen = obj.loading_screen
@ -94,3 +111,9 @@ class ShortUrlMutation(MutationABC):
short_url = await shortUrlDao.get_by_id(id) short_url = await shortUrlDao.get_by_id(id)
await shortUrlDao.restore(short_url) await shortUrlDao.restore(short_url)
return True return True
@staticmethod
async def resolve_track_visit(*_, id: int, agent: str):
logger.debug(f"track visit: {id} -- {agent}")
await shortUrlVisitDao.create(ShortUrlVisit(0, id, agent))
return True

View File

@ -0,0 +1,40 @@
from api.route import Route
from api_graphql.abc.mutation_abc import MutationABC
from api_graphql.input.user_setting_input import UserSettingInput
from core.logger import APILogger
from data.schemas.public.user_setting import UserSetting
from data.schemas.public.user_setting_dao import userSettingsDao
from data.schemas.system.setting_dao import settingsDao
from service.permission.permissions_enum import Permissions
logger = APILogger(__name__)
class UserSettingMutation(MutationABC):
def __init__(self):
MutationABC.__init__(self, "UserSetting")
self.mutation(
"change",
self.resolve_change,
UserSettingInput,
require_any_permission=[Permissions.settings_update],
)
@staticmethod
async def resolve_change(obj: UserSettingInput, *_):
logger.debug(f"create new setting: {input}")
user = await Route.get_user_or_default()
if user is None:
logger.debug("user not authorized")
return None
setting = await userSettingsDao.find_single_by(
[{UserSetting.user_id: user.id}, {UserSetting.key: obj.key}]
)
if setting is None:
await userSettingsDao.create(UserSetting(0, user.id, obj.key, obj.value))
else:
setting.value = obj.value
await userSettingsDao.update(setting)
return await userSettingsDao.find_by_key(user, obj.key)

View File

@ -0,0 +1,17 @@
from api_graphql.abc.db_model_query_abc import DbModelQueryABC
from data.schemas.public.domain import Domain
from data.schemas.public.group import Group
from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao
class DomainQuery(DbModelQueryABC):
def __init__(self):
DbModelQueryABC.__init__(self, "Domain")
self.set_field("name", lambda x, *_: x.name)
self.set_field("shortUrls", self._get_urls)
@staticmethod
async def _get_urls(domain: Domain, *_):
return await shortUrlDao.find_by({ShortUrl.domain_id: domain.id})

View File

@ -1,7 +1,11 @@
from api_graphql.abc.db_model_query_abc import DbModelQueryABC from api_graphql.abc.db_model_query_abc import DbModelQueryABC
from api_graphql.field.resolver_field_builder import ResolverFieldBuilder
from api_graphql.require_any_resolvers import group_by_assignment_resolver
from data.schemas.public.group import Group from data.schemas.public.group import Group
from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao from data.schemas.public.short_url_dao import shortUrlDao
from service.permission.permissions_enum import Permissions
class GroupQuery(DbModelQueryABC): class GroupQuery(DbModelQueryABC):
@ -9,8 +13,22 @@ class GroupQuery(DbModelQueryABC):
DbModelQueryABC.__init__(self, "Group") DbModelQueryABC.__init__(self, "Group")
self.set_field("name", lambda x, *_: x.name) self.set_field("name", lambda x, *_: x.name)
self.set_field("shortUrls", self._get_urls) self.field(
ResolverFieldBuilder("shortUrls")
.with_resolver(self._get_urls)
.with_require_any(
[
Permissions.groups,
],
[group_by_assignment_resolver],
)
)
self.set_field("roles", self._get_roles)
@staticmethod @staticmethod
async def _get_urls(group: Group, *_): async def _get_urls(group: Group, *_):
return await shortUrlDao.find_by({ShortUrl.group_id: group.id}) return await shortUrlDao.find_by({ShortUrl.group_id: group.id})
@staticmethod
async def _get_roles(group: Group, *_):
return await groupDao.get_roles(group.id)

View File

@ -9,5 +9,6 @@ class ShortUrlQuery(DbModelQueryABC):
self.set_field("targetUrl", lambda x, *_: x.target_url) self.set_field("targetUrl", lambda x, *_: x.target_url)
self.set_field("description", lambda x, *_: x.description) self.set_field("description", lambda x, *_: x.description)
self.set_field("group", lambda x, *_: x.group) self.set_field("group", lambda x, *_: x.group)
self.set_field("domain", lambda x, *_: x.domain)
self.set_field("visits", lambda x, *_: x.visit_count) self.set_field("visits", lambda x, *_: x.visit_count)
self.set_field("loadingScreen", lambda x, *_: x.loading_screen) self.set_field("loadingScreen", lambda x, *_: x.loading_screen)

View File

@ -5,11 +5,13 @@ from api_graphql.abc.sort_abc import Sort
from api_graphql.field.dao_field_builder import DaoFieldBuilder from api_graphql.field.dao_field_builder import DaoFieldBuilder
from api_graphql.field.resolver_field_builder import ResolverFieldBuilder from api_graphql.field.resolver_field_builder import ResolverFieldBuilder
from api_graphql.filter.api_key_filter import ApiKeyFilter from api_graphql.filter.api_key_filter import ApiKeyFilter
from api_graphql.filter.domain_filter import DomainFilter
from api_graphql.filter.group_filter import GroupFilter from api_graphql.filter.group_filter import GroupFilter
from api_graphql.filter.permission_filter import PermissionFilter from api_graphql.filter.permission_filter import PermissionFilter
from api_graphql.filter.role_filter import RoleFilter from api_graphql.filter.role_filter import RoleFilter
from api_graphql.filter.short_url_filter import ShortUrlFilter from api_graphql.filter.short_url_filter import ShortUrlFilter
from api_graphql.filter.user_filter import UserFilter from api_graphql.filter.user_filter import UserFilter
from api_graphql.require_any_resolvers import group_by_assignment_resolver
from data.schemas.administration.api_key import ApiKey from data.schemas.administration.api_key import ApiKey
from data.schemas.administration.api_key_dao import apiKeyDao from data.schemas.administration.api_key_dao import apiKeyDao
from data.schemas.administration.user import User from data.schemas.administration.user import User
@ -18,10 +20,16 @@ from data.schemas.permission.permission import Permission
from data.schemas.permission.permission_dao import permissionDao from data.schemas.permission.permission_dao import permissionDao
from data.schemas.permission.role import Role from data.schemas.permission.role import Role
from data.schemas.permission.role_dao import roleDao from data.schemas.permission.role_dao import roleDao
from data.schemas.public.domain import Domain
from data.schemas.public.domain_dao import domainDao
from data.schemas.public.group import Group from data.schemas.public.group import Group
from data.schemas.public.group_dao import groupDao from data.schemas.public.group_dao import groupDao
from data.schemas.public.short_url import ShortUrl from data.schemas.public.short_url import ShortUrl
from data.schemas.public.short_url_dao import shortUrlDao from data.schemas.public.short_url_dao import shortUrlDao
from data.schemas.public.user_setting import UserSetting
from data.schemas.public.user_setting_dao import userSettingsDao
from data.schemas.system.feature_flag_dao import featureFlagDao
from data.schemas.system.setting_dao import settingsDao
from service.permission.permissions_enum import Permissions from service.permission.permissions_enum import Permissions
@ -48,7 +56,15 @@ class Query(QueryABC):
.with_dao(roleDao) .with_dao(roleDao)
.with_filter(RoleFilter) .with_filter(RoleFilter)
.with_sort(Sort[Role]) .with_sort(Sort[Role])
.with_require_any_permission([Permissions.roles]) .with_require_any_permission(
[
Permissions.roles,
Permissions.users_create,
Permissions.users_update,
Permissions.groups_create,
Permissions.groups_update,
]
)
) )
self.field( self.field(
@ -81,25 +97,55 @@ class Query(QueryABC):
) )
self.field( self.field(
DaoFieldBuilder("groups") DaoFieldBuilder("domains")
.with_dao(groupDao) .with_dao(domainDao)
.with_filter(GroupFilter) .with_filter(DomainFilter)
.with_sort(Sort[Group]) .with_sort(Sort[Domain])
.with_require_any_permission( .with_require_any_permission(
[ [
Permissions.groups, Permissions.domains,
Permissions.short_urls_create, Permissions.short_urls_create,
Permissions.short_urls_update, Permissions.short_urls_update,
] ]
) )
) )
# partially public to load redirect if not resolved/redirected by api self.field(
DaoFieldBuilder("groups")
.with_dao(groupDao)
.with_filter(GroupFilter)
.with_sort(Sort[Group])
.with_require_any(
[
Permissions.groups,
Permissions.short_urls_create,
Permissions.short_urls_update,
],
[group_by_assignment_resolver],
)
)
self.field( self.field(
DaoFieldBuilder("shortUrls") DaoFieldBuilder("shortUrls")
.with_dao(shortUrlDao) .with_dao(shortUrlDao)
.with_filter(ShortUrlFilter) .with_filter(ShortUrlFilter)
.with_sort(Sort[ShortUrl]) .with_sort(Sort[ShortUrl])
.with_require_any_permission([Permissions.short_urls]) .with_require_any([Permissions.short_urls], [group_by_assignment_resolver])
)
self.field(
ResolverFieldBuilder("settings")
.with_resolver(self._resolve_settings)
.with_direct_result()
.with_public(True)
)
self.field(
ResolverFieldBuilder("userSettings")
.with_resolver(self._resolve_user_settings)
.with_direct_result()
)
self.field(
ResolverFieldBuilder("featureFlags")
.with_resolver(self._resolve_feature_flags)
.with_direct_result()
) )
@staticmethod @staticmethod
@ -132,3 +178,27 @@ class Query(QueryABC):
for x in kc_users for x in kc_users
if x["id"] not in existing_user_keycloak_ids if x["id"] not in existing_user_keycloak_ids
] ]
@staticmethod
async def _resolve_settings(*args, **kwargs):
if "key" in kwargs:
return [await settingsDao.find_by_key(kwargs["key"])]
return await settingsDao.get_all()
@staticmethod
async def _resolve_user_settings(*args, **kwargs):
user = await Route.get_user()
if user is None:
return None
if "key" in kwargs:
return await userSettingsDao.find_by(
{UserSetting.user_id: user.id, UserSetting.key: kwargs["key"]}
)
return await userSettingsDao.find_by({UserSetting.user_id: user.id})
@staticmethod
async def _resolve_feature_flags(*args, **kwargs):
if "key" in kwargs:
return [await featureFlagDao.find_by_key(kwargs["key"])]
return await featureFlagDao.get_all()

View File

@ -0,0 +1,30 @@
from api_graphql.service.collection_result import CollectionResult
from api_graphql.service.query_context import QueryContext
from data.schemas.public.group_dao import groupDao
from service.permission.permissions_enum import Permissions
async def group_by_assignment_resolver(ctx: QueryContext) -> bool:
if not isinstance(ctx.data, CollectionResult):
return False
if ctx.has_permission(Permissions.short_urls_by_assignment):
groups = [await x.group for x in ctx.data.nodes]
role_ids = {x.id for x in await ctx.user.roles}
filtered_groups = [
g.id
for g in groups
if g is not None
and (roles := await groupDao.get_roles(g.id))
and all(r.id in role_ids for r in roles)
]
ctx.data.nodes = [
node
for node in ctx.data.nodes
if (await node.group) is not None
and (await node.group).id in filtered_groups
]
return True
return True

View File

@ -14,7 +14,7 @@ class QueryContext:
self, self,
data: Any, data: Any,
user: Optional[User], user: Optional[User],
user_permissions: Optional[list[Permission]], user_permissions: Optional[list[Permissions]],
*args, *args,
**kwargs **kwargs
): ):
@ -23,7 +23,7 @@ class QueryContext:
self._user = user self._user = user
if user_permissions is None: if user_permissions is None:
user_permissions = [] user_permissions = []
self._user_permissions: list[str] = [x.name for x in user_permissions] self._user_permissions: list[str] = [x.value for x in user_permissions]
self._resolve_info = None self._resolve_info = None
for arg in args: for arg in args:

View File

@ -5,6 +5,7 @@ from ariadne import make_executable_schema, load_schema_from_path
from api_graphql.definition import QUERIES from api_graphql.definition import QUERIES
from api_graphql.mutation import Mutation from api_graphql.mutation import Mutation
from api_graphql.query import Query from api_graphql.query import Query
from api_graphql.subscription import Subscription
type_defs = load_schema_from_path( type_defs = load_schema_from_path(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "../graphql/") os.path.join(os.path.dirname(os.path.realpath(__file__)), "../graphql/")
@ -13,5 +14,6 @@ schema = make_executable_schema(
type_defs, type_defs,
Query(), Query(),
Mutation(), Mutation(),
Subscription(),
*QUERIES, *QUERIES,
) )

View File

@ -0,0 +1,66 @@
from api_graphql.abc.subscription_abc import SubscriptionABC
from api_graphql.field.subscription_field_builder import SubscriptionFieldBuilder
from service.permission.permissions_enum import Permissions
class Subscription(SubscriptionABC):
def __init__(self):
SubscriptionABC.__init__(self)
self.subscribe(
SubscriptionFieldBuilder("ping")
.with_resolver(lambda message, *_: message.message)
.with_public(True)
)
self.subscribe(
SubscriptionFieldBuilder("apiKeyChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.api_keys])
)
self.subscribe(
SubscriptionFieldBuilder("featureFlagChange")
.with_resolver(lambda message, *_: message.message)
.with_public(True)
)
self.subscribe(
SubscriptionFieldBuilder("roleChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.roles])
)
self.subscribe(
SubscriptionFieldBuilder("settingChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.settings])
)
self.subscribe(
SubscriptionFieldBuilder("userChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.users])
)
self.subscribe(
SubscriptionFieldBuilder("userSettingChange")
.with_resolver(lambda message, *_: message.message)
.with_public(True)
)
self.subscribe(
SubscriptionFieldBuilder("domainChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.domains])
)
self.subscribe(
SubscriptionFieldBuilder("groupChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.groups])
)
self.subscribe(
SubscriptionFieldBuilder("shortUrlChange")
.with_resolver(lambda message, *_: message.message)
.with_require_any_permission([Permissions.short_urls])
)

View File

@ -1,11 +1,15 @@
from collections.abc import Awaitable from collections.abc import Awaitable
from typing import Callable, Union, Optional from typing import Callable, Union, Optional, Coroutine, Any
from api_graphql.service.query_context import QueryContext from api_graphql.service.query_context import QueryContext
from service.permission.permissions_enum import Permissions from service.permission.permissions_enum import Permissions
TRequireAnyPermissions = Optional[list[Permissions]] TRequireAnyPermissions = Optional[list[Permissions]]
TRequireAnyResolvers = list[ TRequireAnyResolvers = list[
Union[Callable[[QueryContext], bool], Awaitable[[QueryContext], bool]] Union[
Callable[[QueryContext], bool],
Awaitable[[QueryContext], bool],
Callable[[QueryContext], Coroutine[Any, Any, bool]],
]
] ]
TRequireAny = tuple[TRequireAnyPermissions, TRequireAnyResolvers] TRequireAny = tuple[TRequireAnyPermissions, TRequireAnyResolvers]

View File

View File

@ -0,0 +1,20 @@
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from data.schemas.system.feature_flag_dao import featureFlagDao
class FeatureFlags:
_flags = {
FeatureFlagsEnum.version_endpoint.value: True, # 15.01.2025
}
@staticmethod
def get_default(key: FeatureFlagsEnum) -> bool:
return FeatureFlags._flags[key.value]
@staticmethod
async def has_feature(key: FeatureFlagsEnum) -> bool:
value = await featureFlagDao.find_by_key(key.value)
if value is None:
return False
return value.value

View File

@ -0,0 +1,6 @@
from enum import Enum
class FeatureFlagsEnum(Enum):
# modules
version_endpoint = "VersionEndpoint"

1
api/src/core/const.py Normal file
View File

@ -0,0 +1 @@
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z"

View File

@ -4,9 +4,12 @@ from enum import Enum
from types import NoneType from types import NoneType
from typing import Generic, Optional, Union, TypeVar, Any, Type from typing import Generic, Optional, Union, TypeVar, Any, Type
from core.const import DATETIME_FORMAT
from core.database.abc.db_model_abc import DbModelABC from core.database.abc.db_model_abc import DbModelABC
from core.database.database import Database from core.database.database import Database
from core.get_value import get_value
from core.logger import DBLogger from core.logger import DBLogger
from core.string import camel_to_snake
from core.typing import T, Attribute, AttributeFilters, AttributeSorts from core.typing import T, Attribute, AttributeFilters, AttributeSorts
T_DBM = TypeVar("T_DBM", bound=DbModelABC) T_DBM = TypeVar("T_DBM", bound=DbModelABC)
@ -23,7 +26,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
self._default_filter_condition = None self._default_filter_condition = None
self.__attributes: dict[str, type] = {} self.__attributes: dict[str, type] = {}
self.__joins: dict[str, str] = {}
self.__db_names: dict[str, str] = {} self.__db_names: dict[str, str] = {}
self.__foreign_tables: dict[str, str] = {}
self.__date_attributes: set[str] = set() self.__date_attributes: set[str] = set()
self.__ignored_attributes: set[str] = set() self.__ignored_attributes: set[str] = set()
@ -69,6 +76,40 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if attr_type in [datetime, datetime.datetime]: if attr_type in [datetime, datetime.datetime]:
self.__date_attributes.add(db_name) self.__date_attributes.add(db_name)
def reference(
self,
attr: Attribute,
primary_attr: Attribute,
foreign_attr: Attribute,
table_name: str,
):
"""
Add a reference to another table for the given attribute
:param str primary_attr: Name of the primary key in the foreign object
:param str foreign_attr: Name of the foreign key in the object
:param str table_name: Name of the table to reference
:return:
"""
if table_name == self._table_name:
return
if isinstance(attr, property):
attr = attr.fget.__name__
if isinstance(primary_attr, property):
primary_attr = primary_attr.fget.__name__
primary_attr = primary_attr.lower().replace("_", "")
if isinstance(foreign_attr, property):
foreign_attr = foreign_attr.fget.__name__
foreign_attr = foreign_attr.lower().replace("_", "")
self.__joins[foreign_attr] = (
f"LEFT JOIN {table_name} ON {table_name}.{primary_attr} = {self._table_name}.{foreign_attr}"
)
self.__foreign_tables[attr] = table_name
def to_object(self, result: dict) -> T_DBM: def to_object(self, result: dict) -> T_DBM:
""" """
Convert a result from the database to an object Convert a result from the database to an object
@ -89,8 +130,13 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
return self._model_type(**value_map) return self._model_type(**value_map)
async def count(self) -> int: async def count(self, filters: AttributeFilters = None) -> int:
result = await self._db.select_map(f"SELECT COUNT(*) FROM {self._table_name}") query = f"SELECT COUNT(*) FROM {self._table_name}"
if filters is not None and (not isinstance(filters, list) or len(filters) > 0):
query += f" WHERE {self._build_conditions(filters)}"
result = await self._db.select_map(query)
return result[0]["count"] return result[0]["count"]
async def get_all(self) -> list[T_DBM]: async def get_all(self) -> list[T_DBM]:
@ -370,6 +416,9 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if isinstance(value, NoneType): if isinstance(value, NoneType):
return "NULL" return "NULL"
if value is None:
return "NULL"
if isinstance(value, Enum): if isinstance(value, Enum):
return str(value.value) return str(value.value)
@ -381,6 +430,12 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
return "ARRAY[]::text[]" return "ARRAY[]::text[]"
return f"ARRAY[{", ".join([DataAccessObjectABC._get_value_sql(x) for x in value])}]" return f"ARRAY[{", ".join([DataAccessObjectABC._get_value_sql(x) for x in value])}]"
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
value = value.replace(tzinfo=datetime.timezone.utc)
return f"'{value.strftime(DATETIME_FORMAT)}'"
return str(value) return str(value)
@staticmethod @staticmethod
@ -409,15 +464,18 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
take: int = None, take: int = None,
skip: int = None, skip: int = None,
) -> str: ) -> str:
query = f"SELECT * FROM {self._table_name}" query = f"SELECT {self._table_name}.* FROM {self._table_name}"
if filters and len(filters) > 0: for join in self.__joins:
query += f" {self.__joins[join]}"
if filters is not None and (not isinstance(filters, list) or len(filters) > 0):
query += f" WHERE {self._build_conditions(filters)}" query += f" WHERE {self._build_conditions(filters)}"
if sorts and len(sorts) > 0: if sorts is not None and (not isinstance(sorts, list) or len(sorts) > 0):
query += f" ORDER BY {self._build_order_by(sorts)}" query += f" ORDER BY {self._build_order_by(sorts)}"
if take: if take is not None:
query += f" LIMIT {take}" query += f" LIMIT {take}"
if skip: if skip is not None:
query += f" OFFSET {skip}" query += f" OFFSET {skip}"
return query return query
@ -435,12 +493,41 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
for attr, values in f.items(): for attr, values in f.items():
if isinstance(attr, property): if isinstance(attr, property):
attr = attr.fget.__name__ attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
conditions.extend(
self._build_foreign_conditions(foreign_table, values)
)
continue
if attr == "fuzzy":
conditions.append(
" OR ".join(
self._build_fuzzy_conditions(
[
(
self.__db_names[x]
if x in self.__db_names
else self.__db_names[camel_to_snake(x)]
)
for x in get_value(values, "fields", list[str])
],
get_value(values, "term", str),
get_value(values, "threshold", int, 5),
)
)
)
continue
db_name = self.__db_names[attr] db_name = self.__db_names[attr]
if isinstance(values, dict): if isinstance(values, dict):
for operator, value in values.items(): for operator, value in values.items():
conditions.append( conditions.append(
self._build_condition(db_name, operator, value) self._build_condition(
f"{self._table_name}.{db_name}", operator, value
)
) )
elif isinstance(values, list): elif isinstance(values, list):
sub_conditions = [] sub_conditions = []
@ -448,18 +535,80 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if isinstance(value, dict): if isinstance(value, dict):
for operator, val in value.items(): for operator, val in value.items():
sub_conditions.append( sub_conditions.append(
self._build_condition(db_name, operator, val) self._build_condition(
f"{self._table_name}.{db_name}", operator, val
)
) )
else: else:
sub_conditions.append( sub_conditions.append(
f"{db_name} = {self._get_value_sql(value)}" self._get_value_validation_sql(db_name, value)
) )
conditions.append(f"({' OR '.join(sub_conditions)})") conditions.append(f"({' OR '.join(sub_conditions)})")
else: else:
conditions.append(f"{db_name} = {self._get_value_sql(values)}") conditions.append(self._get_value_validation_sql(db_name, values))
return " AND ".join(conditions) return " AND ".join(conditions)
def _build_fuzzy_conditions(
self, fields: list[str], term: str, threshold: int = 10
) -> list[str]:
conditions = []
for field in fields:
conditions.append(
f"levenshtein({field}, '{term}') <= {threshold}"
) # Adjust the threshold as needed
return conditions
def _build_foreign_conditions(self, table: str, values: dict) -> list[str]:
"""
Build SQL conditions for foreign key references
:param table: Foreign table name
:param values: Filter values
:return: List of conditions
"""
conditions = []
for attr, sub_values in values.items():
if isinstance(attr, property):
attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
conditions.extend(
self._build_foreign_conditions(foreign_table, sub_values)
)
continue
db_name = f"{table}.{attr.lower().replace('_', '')}"
if isinstance(sub_values, dict):
for operator, value in sub_values.items():
conditions.append(self._build_condition(db_name, operator, value))
elif isinstance(sub_values, list):
sub_conditions = []
for value in sub_values:
if isinstance(value, dict):
for operator, val in value.items():
sub_conditions.append(
self._build_condition(db_name, operator, val)
)
else:
sub_conditions.append(
self._get_value_validation_sql(db_name, value)
)
conditions.append(f"({' OR '.join(sub_conditions)})")
else:
conditions.append(self._get_value_validation_sql(db_name, sub_values))
return conditions
def _get_value_validation_sql(self, field: str, value: Any):
value = self._get_value_sql(value)
if value == "NULL":
return f"{self._table_name}.{field} IS NULL"
return f"{self._table_name}.{field} = {value}"
def _build_condition(self, db_name: str, operator: str, value: Any) -> str: def _build_condition(self, db_name: str, operator: str, value: Any) -> str:
""" """
Build individual SQL condition based on the operator Build individual SQL condition based on the operator
@ -520,6 +669,13 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
if isinstance(attr, property): if isinstance(attr, property):
attr = attr.fget.__name__ attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
sort_clauses.extend(
self._build_foreign_order_by(foreign_table, direction)
)
continue
match attr: match attr:
case "createdUtc": case "createdUtc":
attr = "created" attr = "created"
@ -537,6 +693,30 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]):
return ", ".join(sort_clauses) return ", ".join(sort_clauses)
def _build_foreign_order_by(self, table: str, direction: str) -> list[str]:
"""
Build SQL order by clause for foreign key references
:param table: Foreign table name
:param direction: Sort direction
:return: List of order by clauses
"""
sort_clauses = []
for attr, sub_direction in direction.items():
if isinstance(attr, property):
attr = attr.fget.__name__
if attr in self.__foreign_tables:
foreign_table = self.__foreign_tables[attr]
sort_clauses.extend(
self._build_foreign_order_by(foreign_table, sub_direction)
)
continue
db_name = f"{table}.{attr.lower().replace('_', '')}"
sort_clauses.append(f"{db_name} {sub_direction.upper()}")
return sort_clauses
@staticmethod @staticmethod
async def _get_editor_id(obj: T_DBM): async def _get_editor_id(obj: T_DBM):
editor_id = obj.editor_id editor_id = obj.editor_id

View File

@ -1,5 +1,4 @@
import ast from typing import Type, Optional
from typing import Type, Optional, Any
from core.typing import T from core.typing import T
@ -22,26 +21,38 @@ def get_value(
:rtype: Optional[T] :rtype: Optional[T]
""" """
if key in source: if key not in source:
value = source[key] return default
if isinstance(value, cast_type):
return value value = source[key]
if isinstance(
try: value,
if cast_type == bool: cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__,
return value.lower() in ["true", "1"] ):
return value
if cast_type == list:
subtype = ( try:
cast_type.__args__[0] if hasattr(cast_type, "__args__") else None if cast_type == bool:
) return value.lower() in ["true", "1"]
value = ast.literal_eval(value)
return [ if (
subtype(item) if subtype is not None else item for item in value cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__
] ) == list:
if (
return cast_type(value) not (value.startswith("[") and value.endswith("]"))
except (ValueError, TypeError): and list_delimiter not in value
return default ):
else: raise ValueError(
"List values must be enclosed in square brackets or use a delimiter."
)
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
value = value.split(list_delimiter)
subtype = cast_type.__args__[0] if hasattr(cast_type, "__args__") else None
return [subtype(item) if subtype is not None else item for item in value]
return cast_type(value)
except (ValueError, TypeError):
return default return default

View File

@ -1,8 +1,13 @@
import asyncio
import os import os
import traceback import traceback
from datetime import datetime from datetime import datetime
from api.middleware.request import get_request
from core.environment import Environment
class Logger: class Logger:
_level = "info" _level = "info"
_levels = ["trace", "debug", "info", "warning", "error", "fatal"] _levels = ["trace", "debug", "info", "warning", "error", "fatal"]
@ -54,6 +59,34 @@ class Logger:
else: else:
raise ValueError(f"Invalid log level: {level}") raise ValueError(f"Invalid log level: {level}")
def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str:
structured_message = {
"timestamp": timestamp,
"level": level.upper(),
"source": self.source,
"messages": messages,
}
request = get_request()
if request is not None:
structured_message["request"] = {
"url": str(request.url),
"method": request.method if request.scope == "http" else "ws",
"data": (
asyncio.create_task(request.body())
if request.scope == "http"
else None
),
}
return str(structured_message)
def _write_log_to_file(self, content: str):
self._ensure_file_size()
with open(self.log_file, "a") as log_file:
log_file.write(content + "\n")
log_file.close()
def _log(self, level: str, *messages): def _log(self, level: str, *messages):
try: try:
if self._levels.index(level) < self._levels.index(self._level): if self._levels.index(level) < self._levels.index(self._level):
@ -63,17 +96,18 @@ class Logger:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
formatted_message = f"<{timestamp}> [{level.upper():^7}] [{self._file_prefix:^5}] - [{self.source}]: {' '.join(messages)}" formatted_message = f"<{timestamp}> [{level.upper():^7}] [{self._file_prefix:^5}] - [{self.source}]: {' '.join(messages)}"
self._ensure_file_size() if Environment.get("STRUCTURED_LOGGING", bool, False):
with open(self.log_file, "a") as log_file: self._write_log_to_file(
log_file.write(formatted_message + "\n") self._get_structured_message(level, timestamp, " ".join(messages))
log_file.close() )
else:
self._write_log_to_file(formatted_message)
color = self.COLORS.get(level, self.COLORS["reset"]) print(
reset_color = self.COLORS["reset"] f"{self.COLORS.get(level, self.COLORS["reset"])}{formatted_message}{self.COLORS["reset"]}"
)
print(f"{color}{formatted_message}{reset_color}")
except Exception as e: except Exception as e:
print(f"Error while logging: {e}") print(f"Error while logging: {e} -> {traceback.format_exc()}")
def trace(self, *messages): def trace(self, *messages):
self._log("trace", *messages) self._log("trace", *messages)

9
api/src/core/string.py Normal file
View File

@ -0,0 +1,9 @@
import re
def first_to_lower(s: str) -> str:
return s[0].lower() + s[1:] if s else s
def camel_to_snake(s: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()

View File

@ -42,17 +42,9 @@ class User(DbModelABC):
@async_property @async_property
async def permissions(self): async def permissions(self):
from data.schemas.permission.role_user_dao import roleUserDao from data.schemas.administration.user_dao import userDao
from data.schemas.permission.role_permission_dao import rolePermissionDao
from data.schemas.permission.permission_dao import permissionDao
x = [ return await userDao.get_permissions(self.id)
rp.permission_id
for x in await roleUserDao.get_by_user_id(self.id)
for rp in await rolePermissionDao.get_by_role_id(x.role_id)
]
return await permissionDao.get_by({"id": {"in": x}})
async def has_permission(self, permission: Permissions) -> bool: async def has_permission(self, permission: Permissions) -> bool:
from data.schemas.administration.user_dao import userDao from data.schemas.administration.user_dao import userDao

View File

@ -33,13 +33,30 @@ class UserDao(DbModelDaoABC[User]):
SELECT COUNT(*) SELECT COUNT(*)
FROM permission.role_users ru FROM permission.role_users ru
JOIN permission.role_permissions rp ON ru.roleId = rp.roleId JOIN permission.role_permissions rp ON ru.roleId = rp.roleId
WHERE ru.userId = {user_id} AND rp.permissionId = {p.id}; WHERE ru.userId = {user_id}
AND rp.permissionId = {p.id}
AND ru.deleted = FALSE
AND rp.deleted = FALSE;
""" """
) )
if result is None or len(result) == 0: if result is None or len(result) == 0:
return False return False
return True return result[0]["count"] > 0
async def get_permissions(self, user_id: int) -> list[Permissions]:
result = await self._db.select_map(
f"""
SELECT p.*
FROM permission.permissions p
JOIN permission.role_permissions rp ON p.id = rp.permissionId
JOIN permission.role_users ru ON rp.roleId = ru.roleId
WHERE ru.userId = {user_id}
AND rp.deleted = FALSE
AND ru.deleted = FALSE;
"""
)
return [Permissions(p["name"]) for p in result]
userDao = UserDao() userDao = UserDao()

View File

@ -0,0 +1,27 @@
from datetime import datetime
from typing import Optional
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class Domain(DbModelABC):
def __init__(
self,
id: SerialId,
name: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._name = name
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value: str):
self._name = value

View File

@ -0,0 +1,22 @@
from core.logger import DBLogger
from data.schemas.public.domain import Domain
from data.schemas.public.group import Group
logger = DBLogger(__name__)
from core.database.abc.db_model_dao_abc import DbModelDaoABC
class DomainDao(DbModelDaoABC[Group]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, Group, "public.domains")
self.attribute(Domain.name, str)
async def get_by_name(self, name: str) -> Group:
result = await self._db.select_map(
f"SELECT * FROM {self._table_name} WHERE Name = '{name}'"
)
return self.to_object(result[0])
domainDao = DomainDao()

View File

@ -17,5 +17,19 @@ class GroupDao(DbModelDaoABC[Group]):
) )
return self.to_object(result[0]) return self.to_object(result[0])
async def get_roles(self, group_id: int):
result = await self._db.select_map(
f"""
SELECT r.*
FROM permission.roles r
JOIN public.group_role_assignments gra ON r.id = gra.roleId
WHERE gra.groupId = {group_id}
AND gra.deleted = FALSE
"""
)
from data.schemas.permission.role_dao import roleDao
return [roleDao.to_object(x) for x in result]
groupDao = GroupDao() groupDao = GroupDao()

View File

@ -0,0 +1,43 @@
from datetime import datetime
from typing import Optional
from async_property import async_property
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class GroupRoleAssignment(DbModelABC):
def __init__(
self,
id: SerialId,
group_id: SerialId,
role_id: SerialId,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._group_id = group_id
self._role_id = role_id
@property
def group_id(self) -> SerialId:
return self._group_id
@async_property
async def group(self):
from data.schemas.public.group_dao import groupDao
return await groupDao.get_by_id(self._group_id)
@property
def role_id(self) -> SerialId:
return self._role_id
@async_property
async def role(self):
from data.schemas.permission.role_dao import roleDao
return await roleDao.get_by_id(self._role_id)

View File

@ -0,0 +1,37 @@
from core.logger import DBLogger
from data.schemas.public.group_role_assignment import GroupRoleAssignment
logger = DBLogger(__name__)
from core.database.abc.db_model_dao_abc import DbModelDaoABC
class GroupRoleAssignmentDao(DbModelDaoABC[GroupRoleAssignment]):
def __init__(self):
DbModelDaoABC.__init__(
self, __name__, GroupRoleAssignment, "public.group_role_assignments"
)
self.attribute(GroupRoleAssignment.group_id, int)
self.attribute(GroupRoleAssignment.role_id, int)
async def get_by_group_id(
self, gid: int, with_deleted=False
) -> list[GroupRoleAssignment]:
f = [{GroupRoleAssignment.group_id: gid}]
if not with_deleted:
f.append({GroupRoleAssignment.deleted: False})
return await self.find_by(f)
async def get_by_role_id(
self, rid: int, with_deleted=False
) -> list[GroupRoleAssignment]:
f = [{GroupRoleAssignment.role_id: rid}]
if not with_deleted:
f.append({GroupRoleAssignment.deleted: False})
return await self.find_by(f)
groupRoleAssignmentDao = GroupRoleAssignmentDao()

View File

@ -16,6 +16,7 @@ class ShortUrl(DbModelABC):
target_url: str, target_url: str,
description: Optional[str], description: Optional[str],
group_id: Optional[SerialId], group_id: Optional[SerialId],
domain_id: Optional[SerialId],
loading_screen: Optional[str] = None, loading_screen: Optional[str] = None,
deleted: bool = False, deleted: bool = False,
editor_id: Optional[SerialId] = None, editor_id: Optional[SerialId] = None,
@ -27,6 +28,10 @@ class ShortUrl(DbModelABC):
self._target_url = target_url self._target_url = target_url
self._description = description self._description = description
self._group_id = group_id self._group_id = group_id
self._domain_id = domain_id
if loading_screen is None or loading_screen == "":
loading_screen = False
self._loading_screen = loading_screen self._loading_screen = loading_screen
@property @property
@ -70,6 +75,23 @@ class ShortUrl(DbModelABC):
return await groupDao.get_by_id(self._group_id) return await groupDao.get_by_id(self._group_id)
@property
def domain_id(self) -> SerialId:
return self._domain_id
@domain_id.setter
def domain_id(self, value: SerialId):
self._domain_id = value
@async_property
async def domain(self) -> Optional[Group]:
if self._domain_id is None:
return None
from data.schemas.public.domain_dao import domainDao
return await domainDao.get_by_id(self._domain_id)
@async_property @async_property
async def visit_count(self) -> int: async def visit_count(self) -> int:
from data.schemas.public.short_url_visit_dao import shortUrlVisitDao from data.schemas.public.short_url_visit_dao import shortUrlVisitDao

View File

@ -13,6 +13,7 @@ class ShortUrlDao(DbModelDaoABC[ShortUrl]):
self.attribute(ShortUrl.target_url, str) self.attribute(ShortUrl.target_url, str)
self.attribute(ShortUrl.description, str) self.attribute(ShortUrl.description, str)
self.attribute(ShortUrl.group_id, int) self.attribute(ShortUrl.group_id, int)
self.attribute(ShortUrl.domain_id, int)
self.attribute(ShortUrl.loading_screen, bool) self.attribute(ShortUrl.loading_screen, bool)

View File

@ -0,0 +1,48 @@
from datetime import datetime
from typing import Optional, Union
from async_property import async_property
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class UserSetting(DbModelABC):
def __init__(
self,
id: SerialId,
user_id: SerialId,
key: str,
value: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._user_id = user_id
self._key = key
self._value = value
@property
def user_id(self) -> SerialId:
return self._user_id
@async_property
async def user(self):
from data.schemas.administration.user_dao import userDao
return await userDao.get_by_id(self._user_id)
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value
@value.setter
def value(self, value: Union[str, int, float, bool]):
self._value = str(value)

View File

@ -0,0 +1,24 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.administration.user import User
from data.schemas.public.user_setting import UserSetting
logger = DBLogger(__name__)
class UserSettingDao(DbModelDaoABC[UserSetting]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, UserSetting, "public.user_settings")
self.attribute(UserSetting.user_id, int)
self.attribute(UserSetting.key, str)
self.attribute(UserSetting.value, str)
async def find_by_key(self, user: User, key: str) -> UserSetting:
return await self.find_single_by(
[{UserSetting.user_id: user.id}, {UserSetting.key: key}]
)
userSettingsDao = UserSettingDao()

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class FeatureFlag(DbModelABC):
def __init__(
self,
id: SerialId,
key: str,
value: bool,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._key = key
self._value = value
@property
def key(self) -> str:
return self._key
@property
def value(self) -> bool:
return self._value
@value.setter
def value(self, value: bool):
self._value = value

View File

@ -0,0 +1,20 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.system.feature_flag import FeatureFlag
logger = DBLogger(__name__)
class FeatureFlagDao(DbModelDaoABC[FeatureFlag]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, FeatureFlag, "system.feature_flags")
self.attribute(FeatureFlag.key, str)
self.attribute(FeatureFlag.value, bool)
async def find_by_key(self, key: str) -> FeatureFlag:
return await self.find_single_by({FeatureFlag.key: key})
featureFlagDao = FeatureFlagDao()

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional, Union
from core.database.abc.db_model_abc import DbModelABC
from core.typing import SerialId
class Setting(DbModelABC):
def __init__(
self,
id: SerialId,
key: str,
value: str,
deleted: bool = False,
editor_id: Optional[SerialId] = None,
created: Optional[datetime] = None,
updated: Optional[datetime] = None,
):
DbModelABC.__init__(self, id, deleted, editor_id, created, updated)
self._key = key
self._value = value
@property
def key(self) -> str:
return self._key
@property
def value(self) -> str:
return self._value
@value.setter
def value(self, value: Union[str, int, float, bool]):
self._value = str(value)

View File

@ -0,0 +1,20 @@
from core.database.abc.db_model_dao_abc import DbModelDaoABC
from core.logger import DBLogger
from data.schemas.system.setting import Setting
logger = DBLogger(__name__)
class SettingDao(DbModelDaoABC[Setting]):
def __init__(self):
DbModelDaoABC.__init__(self, __name__, Setting, "system.settings")
self.attribute(Setting.key, str)
self.attribute(Setting.value, str)
async def find_by_key(self, key: str) -> Setting:
return await self.find_single_by({Setting.key: key})
settingsDao = SettingDao()

View File

@ -0,0 +1,32 @@
CREATE
SCHEMA IF NOT EXISTS public;
-- groups
CREATE TABLE IF NOT EXISTS public.domains
(
Id SERIAL PRIMARY KEY,
Name VARCHAR(255) NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS public.domains_history
(
LIKE public.domains
);
CREATE TRIGGER domains_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.domains
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();
ALTER TABLE public.short_urls
ADD COLUMN domainId INT NULL REFERENCES public.domains (Id);
ALTER TABLE public.short_urls_history
ADD COLUMN domainId INT NULL REFERENCES public.domains (Id);

View File

@ -0,0 +1,27 @@
CREATE
SCHEMA IF NOT EXISTS public;
-- groups
CREATE TABLE IF NOT EXISTS public.group_role_assignments
(
Id SERIAL PRIMARY KEY,
GroupId INT NOT NULL REFERENCES public.groups (Id),
RoleId INT NOT NULL REFERENCES permission.roles (Id),
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS public.group_role_assignments_history
(
LIKE public.group_role_assignments
);
CREATE TRIGGER group_role_assignment_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.group_role_assignments
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,24 @@
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system.settings
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value TEXT NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE system.settings_history
(
LIKE system.settings
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON system.settings
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,24 @@
CREATE SCHEMA IF NOT EXISTS system;
CREATE TABLE IF NOT EXISTS system.feature_flags
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value BOOLEAN NOT NULL,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE system.feature_flags_history
(
LIKE system.feature_flags
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON system.feature_flags
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,25 @@
CREATE SCHEMA IF NOT EXISTS public;
CREATE TABLE IF NOT EXISTS public.user_settings
(
Id SERIAL PRIMARY KEY,
Key TEXT NOT NULL,
Value TEXT NOT NULL,
UserId INT NOT NULL REFERENCES administration.users (Id) ON DELETE CASCADE,
-- for history
Deleted BOOLEAN NOT NULL DEFAULT FALSE,
EditorId INT NULL REFERENCES administration.users (Id),
CreatedUtc timestamptz NOT NULL DEFAULT NOW(),
UpdatedUtc timestamptz NOT NULL DEFAULT NOW()
);
CREATE TABLE public.user_settings_history
(
LIKE public.user_settings
);
CREATE TRIGGER ip_list_history_trigger
BEFORE INSERT OR UPDATE OR DELETE
ON public.user_settings
FOR EACH ROW
EXECUTE FUNCTION public.history_trigger_function();

View File

@ -0,0 +1,40 @@
from core.configuration.feature_flags import FeatureFlags
from core.configuration.feature_flags_enum import FeatureFlagsEnum
from core.logger import DBLogger
from data.abc.data_seeder_abc import DataSeederABC
from data.schemas.system.feature_flag import FeatureFlag
from data.schemas.system.feature_flag_dao import featureFlagDao
logger = DBLogger(__name__)
class FeatureFlagsSeeder(DataSeederABC):
def __init__(self):
DataSeederABC.__init__(self)
async def seed(self):
logger.info("Seeding feature flags")
feature_flags = await featureFlagDao.get_all()
feature_flag_keys = [x.key for x in feature_flags]
possible_feature_flags = {
x.value: FeatureFlags.get_default(x) for x in FeatureFlagsEnum
}
to_create = [
FeatureFlag(0, x, possible_feature_flags[x])
for x in possible_feature_flags.keys()
if x not in feature_flag_keys
]
if len(to_create) > 0:
await featureFlagDao.create_many(to_create)
to_create_dicts = {x.key: x.value for x in to_create}
logger.debug(f"Created feature flags: {to_create_dicts}")
to_delete = [
x for x in feature_flags if x.key not in possible_feature_flags.keys()
]
if len(to_delete) > 0:
await featureFlagDao.delete_many(to_delete, hard_delete=True)
to_delete_dicts = {x.key: x.value for x in to_delete}
logger.debug(f"Deleted feature flags: {to_delete_dicts}")

Some files were not shown because too many files have changed in this diff Show More