# Listing metadata: 141 lines, 4.4 KiB, Python
import functools
from functools import wraps
from inspect import iscoroutinefunction
from typing import Callable, ClassVar, Optional, Union

from starlette.requests import Request
from starlette.routing import Route as StarletteRoute

from api.errors import unauthorized
from api.middleware.request import get_request
from api.route_api_key_extension import RouteApiKeyExtension
from api.route_user_extension import RouteUserExtension
from core.environment import Environment
from data.schemas.administration.api_key import ApiKey
from data.schemas.administration.user import User


class Route(RouteUserExtension, RouteApiKeyExtension):
    """Route registration decorators plus ``Authorization``-header checks.

    The class is used through its classmethods only. The credential helpers
    it calls (``get_user``, ``get_api_key``, ``get_dev_user``,
    ``verify_login``, ``_verify_api_key``) are not defined here; they are
    presumably inherited from ``RouteUserExtension`` and
    ``RouteApiKeyExtension`` -- confirm against those modules.

    Three header prefixes are recognised (matched case-sensitively, including
    the trailing space): ``"Bearer "``, ``"API-Key "`` and ``"DEV-User "``.
    The last one is honoured only when ``Environment.get_environment()``
    returns ``"development"``.
    """

    # Class-level registry shared by every use of the decorators below;
    # each ``route``/``get``/``post``/... call appends one StarletteRoute.
    registered_routes: ClassVar[list[StarletteRoute]] = []

    @staticmethod
    def _is_dev_user_header(auth_header: str) -> bool:
        """Return True for a ``DEV-User`` header in the development environment."""
        # Prefix test first so the environment is only consulted when needed.
        return (
            auth_header.startswith("DEV-User ")
            and Environment.get_environment() == "development"
        )

    @classmethod
    async def _get_auth_type(
        cls, request: Request, auth_header: str
    ) -> Optional[Union[User, ApiKey]]:
        """Resolve the principal that matches the header's prefix.

        Args:
            request: The current request; only forwarded to ``get_api_key``.
            auth_header: Raw value of the ``Authorization`` header.

        Returns:
            Whatever the matching helper returns, or ``None`` when the
            prefix is not recognised.
        """
        if auth_header.startswith("Bearer "):
            return await cls.get_user()
        if auth_header.startswith("API-Key "):
            return await cls.get_api_key(request)
        if cls._is_dev_user_header(auth_header):
            return await cls.get_dev_user()
        return None

    @classmethod
    async def get_authenticated_user_or_api_key(cls) -> Union[User, ApiKey]:
        """Return the principal for the current request, raising on failure.

        The request is obtained from ``get_request()``.

        Raises:
            ValueError: ``get_request()`` returned ``None``.
            Exception: The ``Authorization`` header is missing/empty, or it
                did not resolve to a user or API key.
        """
        request = get_request()
        if request is None:
            raise ValueError("No request found")

        auth_header = request.headers.get("Authorization", None)
        if not auth_header:
            raise Exception("No Authorization header found")

        user_or_api_key = await cls._get_auth_type(request, auth_header)
        if user_or_api_key is None:
            raise Exception("Invalid Authorization header")
        return user_or_api_key

    @classmethod
    async def get_authenticated_user_or_api_key_or_default(
        cls,
    ) -> Optional[Union[User, ApiKey]]:
        """Non-raising variant of ``get_authenticated_user_or_api_key``.

        Returns:
            The resolved principal, or ``None`` when there is no current
            request, no ``Authorization`` header, or no match for the header.
        """
        request = get_request()
        if request is None:
            return None

        auth_header = request.headers.get("Authorization", None)
        if not auth_header:
            return None

        return await cls._get_auth_type(request, auth_header)

    @classmethod
    async def is_authorized(cls, request: Request) -> bool:
        """Tell whether ``request`` carries an accepted ``Authorization`` header.

        Args:
            request: The incoming request; ``None`` is tolerated and yields
                ``False``.

        Returns:
            The result of ``verify_login`` for ``Bearer`` headers, of
            ``_verify_api_key`` for ``API-Key`` headers, whether
            ``get_dev_user`` returned something for ``DEV-User`` headers in
            development, and ``False`` in every other case.
        """
        if request is None:
            return False

        auth_header = request.headers.get("Authorization", None)
        if not auth_header:
            return False

        if auth_header.startswith("Bearer "):
            return await cls.verify_login(request)
        if auth_header.startswith("API-Key "):
            return await cls._verify_api_key(request)
        if cls._is_dev_user_header(auth_header):
            dev_user = await cls.get_dev_user()
            return dev_user is not None
        return False

    @classmethod
    def authorize(
        cls,
        f: Optional[Callable] = None,
        skip_in_dev=False,
        by_api_key=False,
    ):
        """Decorator that guards a handler with ``is_authorized``.

        Usable bare (``@Route.authorize``) or with options
        (``@Route.authorize(skip_in_dev=True)``).

        Args:
            f: The handler to wrap, sync or async. ``None`` when the
                decorator is invoked with options only.
            skip_in_dev: When true, the check is bypassed if
                ``Environment.get_environment()`` is ``"development"``.
            by_api_key: Accepted and forwarded, but not read anywhere in
                this method. NOTE(review): looks unimplemented -- confirm
                the intended semantics before relying on it.

        Returns:
            An async wrapper around ``f``. It returns ``unauthorized()``
            when the check fails, otherwise the handler's result.
        """
        if f is None:
            # Called with options only: hand back a decorator awaiting `f`.
            return functools.partial(
                cls.authorize, skip_in_dev=skip_in_dev, by_api_key=by_api_key
            )

        # Decided once at decoration time instead of on every request.
        handler_is_async = iscoroutinefunction(f)

        @wraps(f)
        async def decorator(request: Request, *args, **kwargs):
            bypass = skip_in_dev and Environment.get_environment() == "development"
            if not bypass and not await cls.is_authorized(request):
                return unauthorized()

            result = f(request, *args, **kwargs)
            if handler_is_async:
                return await result
            return result

        return decorator

    @classmethod
    def route(cls, path=None, **kwargs):
        """Return a decorator that registers the function under ``path``.

        A ``StarletteRoute(path, fn, **kwargs)`` is appended to
        ``registered_routes``; the decorated function is returned unchanged.

        Args:
            path: Passed to ``StarletteRoute`` as-is.
            **kwargs: Extra keyword arguments for ``StarletteRoute``
                (for example ``methods``).
        """

        def inner(fn):
            cls.registered_routes.append(StarletteRoute(path, fn, **kwargs))
            return fn

        return inner

    @classmethod
    def get(cls, path=None, **kwargs):
        """Shortcut for ``route`` with ``methods=["GET"]``."""
        return cls.route(path, methods=["GET"], **kwargs)

    @classmethod
    def post(cls, path=None, **kwargs):
        """Shortcut for ``route`` with ``methods=["POST"]``."""
        return cls.route(path, methods=["POST"], **kwargs)

    @classmethod
    def head(cls, path=None, **kwargs):
        """Shortcut for ``route`` with ``methods=["HEAD"]``."""
        return cls.route(path, methods=["HEAD"], **kwargs)

    @classmethod
    def put(cls, path=None, **kwargs):
        """Shortcut for ``route`` with ``methods=["PUT"]``."""
        return cls.route(path, methods=["PUT"], **kwargs)

    @classmethod
    def delete(cls, path=None, **kwargs):
        """Shortcut for ``route`` with ``methods=["DELETE"]``."""
        return cls.route(path, methods=["DELETE"], **kwargs)