WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
9 changed files with 111 additions and 13 deletions
Showing only changes of commit eceff6128b - Show all commits

View File

@@ -1,9 +1,15 @@
from http.client import HTTPException from http.client import HTTPException
from starlette.responses import JSONResponse
class APIError(HTTPException): class APIError(HTTPException):
status_code = 500 status_code = 500
@classmethod
def response(cls):
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
class Unauthorized(APIError): class Unauthorized(APIError):
status_code = 401 status_code = 401

View File

@@ -0,0 +1,49 @@
from keycloak import KeycloakAuthenticationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from cpl.api.api_logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(BaseHTTPMiddleware):
@classmethod
async def _verify_login(cls, token: str) -> bool:
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
try:
user_info = keycloak.userinfo(token)
if not user_info:
return False
except KeycloakAuthenticationError:
return False
return True
async def dispatch(self, request: Request, call_next):
url = request.url.path
if url not in Router.get_auth_required_routes():
_logger.trace(f"No authentication required for {url}")
return await call_next(request)
if not request.headers.get("Authorization"):
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
return Unauthorized(f"Missing header Authorization").response()
auth_header = request.headers.get("Authorization", None)
if not auth_header or not auth_header.startswith("Bearer "):
return Unauthorized("Invalid Authorization header").response()
if not await self._verify_login(auth_header.split("Bearer ")[1]):
_logger.debug(f"Unauthorized access to {url}, invalid token")
return Unauthorized("Invalid token").response()
# check user exists in db, if not create
# unauthorized if user is deleted
return await call_next(request)

View File

@@ -3,11 +3,35 @@ from starlette.routing import Route
class Router: class Router:
_registered_routes: list[Route] = [] _registered_routes: list[Route] = []
_auth_required: list[str] = []
@classmethod @classmethod
def get_routes(cls) -> list[Route]: def get_routes(cls) -> list[Route]:
return cls._registered_routes return cls._registered_routes
@classmethod
def get_auth_required_routes(cls) -> list[str]:
return cls._auth_required
@classmethod
def authenticate(cls):
"""
Decorator to mark a route as requiring authentication.
Usage:
@Route.authenticate()
@Route.get("/example")
async def example_endpoint(request: TRequest):
...
"""
def inner(fn):
route_path = getattr(fn, "_route_path", None)
if route_path and route_path not in cls._auth_required:
cls._auth_required.append(route_path)
return fn
return inner
@classmethod @classmethod
def route(cls, path=None, **kwargs): def route(cls, path=None, **kwargs):
def inner(fn): def inner(fn):

View File

@@ -13,6 +13,7 @@ from starlette.types import ExceptionHandler
from cpl.api.api_logger import APILogger from cpl.api.api_logger import APILogger
from cpl.api.api_settings import ApiSettings from cpl.api.api_settings import ApiSettings
from cpl.api.error import APIError from cpl.api.error import APIError
from cpl.api.middleware.authentication import AuthenticationMiddleware
from cpl.api.middleware.logging import LoggingMiddleware from cpl.api.middleware.logging import LoggingMiddleware
from cpl.api.middleware.request import RequestMiddleware from cpl.api.middleware.request import RequestMiddleware
from cpl.api.router import Router from cpl.api.router import Router
@@ -24,7 +25,6 @@ from cpl.dependency.service_provider_abc import ServiceProviderABC
_logger = APILogger("API") _logger = APILogger("API")
class WebApp(ApplicationABC): class WebApp(ApplicationABC):
def __init__(self, services: ServiceProviderABC): def __init__(self, services: ServiceProviderABC):
super().__init__(services) super().__init__(services)
@@ -37,18 +37,22 @@ class WebApp(ApplicationABC):
Middleware(RequestMiddleware), Middleware(RequestMiddleware),
Middleware(LoggingMiddleware), Middleware(LoggingMiddleware),
] ]
self._exception_handlers: Mapping[Any, ExceptionHandler] = {Exception: self.handle_exception} self._exception_handlers: Mapping[Any, ExceptionHandler] = {
Exception: self._handle_exception,
APIError: self._handle_exception,
}
@staticmethod @staticmethod
async def handle_exception(request: Request, exc: Exception): async def _handle_exception(request: Request, exc: Exception):
if isinstance(exc, APIError):
_logger.error(exc)
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
if hasattr(request.state, "request_id"): if hasattr(request.state, "request_id"):
_logger.error(f"Request {request.state.request_id}", exc) _logger.error(f"Request {request.state.request_id}", exc)
else: else:
_logger.error("Request unknown", exc) _logger.error("Request unknown", exc)
if isinstance(exc, APIError):
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
return JSONResponse({"error": str(exc)}, status_code=500) return JSONResponse({"error": str(exc)}, status_code=500)
def _get_allowed_origins(self): def _get_allowed_origins(self):
@@ -96,7 +100,15 @@ class WebApp(ApplicationABC):
self._check_for_app() self._check_for_app()
assert path is not None, "path must not be None" assert path is not None, "path must not be None"
assert fn is not None, "fn must not be None" assert fn is not None, "fn must not be None"
assert method in ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], "method must be a valid HTTP method" assert method in [
"GET",
"POST",
"PUT",
"DELETE",
"PATCH",
"OPTIONS",
"HEAD",
], "method must be a valid HTTP method"
self._routes.append(Route(path, fn, methods=[method], **kwargs)) self._routes.append(Route(path, fn, methods=[method], **kwargs))
return self return self
@@ -105,15 +117,20 @@ class WebApp(ApplicationABC):
if isinstance(middleware, Middleware): if isinstance(middleware, Middleware):
self._middleware.append(middleware) self._middleware.append(middleware)
elif callable(middleware): elif callable(middleware):
self._middleware.append(Middleware(middleware)) self._middleware.append(Middleware(middleware))
else: else:
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable") raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
return self return self
def with_authentication(self):
self.with_middleware(AuthenticationMiddleware)
return self
def with_authorization(self):
pass
def main(self): def main(self):
_logger.debug(f"Preparing API") _logger.debug(f"Preparing API")
if self._app is None: if self._app is None:

View File

@@ -15,6 +15,7 @@ def main():
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET") app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
app.with_routes_directory("routes") app.with_routes_directory("routes")
app.with_logging() app.with_logging()
app.with_authentication()
app.run() app.run()

View File

@@ -7,6 +7,7 @@ from cpl.core.log import Logger
from service import PingService from service import PingService
@Router.authenticate()
@Router.get(f"/ping") @Router.get(f"/ping")
async def ping(r: Request, ping: PingService, logger: Logger): async def ping(r: Request, ping: PingService, logger: Logger):
logger.info(f"Ping: {ping}") logger.info(f"Ping: {ping}")