open-redirect/api/src/api/route.py

141 lines
4.4 KiB
Python

import functools
from functools import wraps
from inspect import iscoroutinefunction
from typing import Callable, Union, Optional
from starlette.requests import Request
from starlette.routing import Route as StarletteRoute
from api.errors import unauthorized
from api.middleware.request import get_request
from api.route_api_key_extension import RouteApiKeyExtension
from api.route_user_extension import RouteUserExtension
from core.environment import Environment
from data.schemas.administration.api_key import ApiKey
from data.schemas.administration.user import User
class Route(RouteUserExtension, RouteApiKeyExtension):
    """Route registry and authentication helpers for the API.

    Handlers are registered through the ``route`` / ``get`` / ``post`` /
    ``head`` / ``put`` / ``delete`` decorators, which collect Starlette routes
    in ``registered_routes``.

    Authentication is dispatched on the prefix of the ``Authorization`` header:

    * ``"Bearer "``   -> user lookup / login verification
    * ``"API-Key "``  -> API key lookup / verification
    * ``"DEV-User "`` -> dev user, honoured only while
      ``Environment.get_environment() == "development"``

    NOTE(review): ``get_user``, ``get_api_key``, ``get_dev_user``,
    ``verify_login`` and ``_verify_api_key`` are not defined in this class;
    presumably they are provided by the extension base classes -- confirm.
    """

    # Class-level list: every call to the ``route`` decorator appends to it.
    registered_routes: list[StarletteRoute] = []

    @classmethod
    async def _get_auth_type(
        cls, request: Request, auth_header: str
    ) -> Optional[Union[User, ApiKey]]:
        """Resolve the principal that matches the ``Authorization`` header.

        Args:
            request: Current request; only forwarded to the API key lookup.
            auth_header: Raw value of the ``Authorization`` header.

        Returns:
            Whatever the matching lookup returns, or ``None`` when the header
            prefix is not recognised (including ``"DEV-User "`` outside the
            development environment).
        """
        if auth_header.startswith("Bearer "):
            return await cls.get_user()

        if auth_header.startswith("API-Key "):
            return await cls.get_api_key(request)

        # The dev shortcut must never be honoured outside development.
        is_dev_header = auth_header.startswith("DEV-User ")
        if is_dev_header and Environment.get_environment() == "development":
            return await cls.get_dev_user()

        return None

    @classmethod
    async def get_authenticated_user_or_api_key(cls) -> Union[User, ApiKey]:
        """Return the authenticated principal of the current request.

        Strict variant: every failure raises instead of returning ``None``.

        Raises:
            ValueError: If ``get_request()`` returns ``None``.
            Exception: If the ``Authorization`` header is missing/empty, or if
                no principal could be resolved from it.
        """
        request = get_request()
        if request is None:
            raise ValueError("No request found")

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            raise Exception("No Authorization header found")

        principal = await cls._get_auth_type(request, auth_header)
        if principal is None:
            raise Exception("Invalid Authorization header")
        return principal

    @classmethod
    async def get_authenticated_user_or_api_key_or_default(
        cls,
    ) -> Optional[Union[User, ApiKey]]:
        """Return the authenticated principal of the current request, if any.

        Lenient variant of ``get_authenticated_user_or_api_key``: returns
        ``None`` when there is no current request, no ``Authorization``
        header, or no principal could be resolved from the header.
        """
        request = get_request()
        if request is None:
            return None

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None

        return await cls._get_auth_type(request, auth_header)

    @classmethod
    async def is_authorized(cls, request: Request) -> bool:
        """Tell whether ``request`` carries credentials that verify.

        Args:
            request: Incoming request; ``None`` is tolerated and treated as
                unauthorized.

        Returns:
            ``False`` for a missing request, a missing/empty ``Authorization``
            header or an unrecognised prefix; otherwise the result of the
            verification that matches the header prefix.
        """
        if request is None:
            return False

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return False

        if auth_header.startswith("Bearer "):
            return await cls.verify_login(request)

        if auth_header.startswith("API-Key "):
            return await cls._verify_api_key(request)

        # The dev shortcut must never be honoured outside development.
        is_dev_header = auth_header.startswith("DEV-User ")
        if is_dev_header and Environment.get_environment() == "development":
            dev_user = await cls.get_dev_user()
            return dev_user is not None

        return False

    @classmethod
    def authorize(
        cls,
        f: Optional[Callable] = None,
        skip_in_dev: bool = False,
        by_api_key: bool = False,
    ):
        """Decorator that guards a request handler with ``is_authorized``.

        Usable bare (``@Route.authorize``) or with options
        (``@Route.authorize(skip_in_dev=True)``); in the second form ``f`` is
        ``None`` and a partial of this method is returned that waits for the
        handler.

        Args:
            f: Handler to wrap. May be a coroutine function or a plain
                function; it is called as ``f(request, *args, **kwargs)``.
            skip_in_dev: When true, the authorization check is skipped while
                ``Environment.get_environment() == "development"``.
            by_api_key: Accepted and forwarded through the partial, but not
                read by the wrapper.
                NOTE(review): currently has no effect -- confirm the intended
                behaviour.

        Returns:
            An async wrapper around ``f`` that returns ``unauthorized()``
            when the check fails, otherwise the handler's result.
        """
        if f is None:
            return functools.partial(
                cls.authorize, skip_in_dev=skip_in_dev, by_api_key=by_api_key
            )

        # Whether the handler needs awaiting is fixed at decoration time, so
        # it is decided once here rather than on every request.
        handler_is_async = iscoroutinefunction(f)

        @wraps(f)
        async def decorator(request: Request, *args, **kwargs):
            bypass = (
                skip_in_dev and Environment.get_environment() == "development"
            )
            if not bypass and not await cls.is_authorized(request):
                return unauthorized()

            result = f(request, *args, **kwargs)
            if handler_is_async:
                return await result
            return result

        return decorator

    @classmethod
    def route(cls, path=None, **kwargs):
        """Decorator factory that registers a handler as a Starlette route.

        Args:
            path: URL path handed to ``StarletteRoute``.
                NOTE(review): defaults to ``None``, which is presumably not a
                usable path -- callers are expected to pass one.
            **kwargs: Forwarded unchanged to ``StarletteRoute`` (for example
                ``methods``).

        Returns:
            A decorator that appends the route to ``registered_routes`` and
            returns the handler unchanged.
        """

        def inner(fn):
            cls.registered_routes.append(StarletteRoute(path, fn, **kwargs))
            return fn

        return inner

    @classmethod
    def get(cls, path=None, **kwargs):
        """Register a handler for ``GET`` requests on ``path``."""
        return cls.route(path, methods=["GET"], **kwargs)

    @classmethod
    def post(cls, path=None, **kwargs):
        """Register a handler for ``POST`` requests on ``path``."""
        return cls.route(path, methods=["POST"], **kwargs)

    @classmethod
    def head(cls, path=None, **kwargs):
        """Register a handler for ``HEAD`` requests on ``path``."""
        return cls.route(path, methods=["HEAD"], **kwargs)

    @classmethod
    def put(cls, path=None, **kwargs):
        """Register a handler for ``PUT`` requests on ``path``."""
        return cls.route(path, methods=["PUT"], **kwargs)

    @classmethod
    def delete(cls, path=None, **kwargs):
        """Register a handler for ``DELETE`` requests on ``path``."""
        return cls.route(path, methods=["DELETE"], **kwargs)