diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py index 4a491042..98481b3b 100644 --- a/src/cpl-api/cpl/api/middleware/authentication.py +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -1,9 +1,8 @@ from keycloak import KeycloakAuthenticationError -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request from cpl.api.api_logger import APILogger from cpl.api.error import Unauthorized +from cpl.api.middleware.request import get_request from cpl.api.router import Router from cpl.auth.keycloak import KeycloakClient from cpl.dependency import ServiceProviderABC @@ -11,27 +10,18 @@ from cpl.dependency import ServiceProviderABC _logger = APILogger(__name__) -class AuthenticationMiddleware(BaseHTTPMiddleware): +class AuthenticationMiddleware: - @classmethod - async def _verify_login(cls, token: str) -> bool: - keycloak = ServiceProviderABC.get_global_service(KeycloakClient) - try: - token_info = keycloak.introspect(token) - return token_info.get("active", False) - except KeycloakAuthenticationError as e: - _logger.debug(f"Keycloak authentication error: {e}") - return False - except Exception as e: - _logger.error(f"Unexpected error during token verification: {e}") - return False + def __init__(self, app): + self._app = app - async def dispatch(self, request: Request, call_next): + async def __call__(self, scope, receive, send): + request = get_request() url = request.url.path if url not in Router.get_auth_required_routes(): _logger.trace(f"No authentication required for {url}") - return await call_next(request) + return await self._app(scope, receive, send) if not request.headers.get("Authorization"): _logger.debug(f"Unauthorized access to {url}, missing Authorization header") @@ -49,4 +39,17 @@ class AuthenticationMiddleware(BaseHTTPMiddleware): # check user exists in db, if not create # unauthorized if user is deleted - return await call_next(request) + return await self._app(scope, receive, send) + + @classmethod + async def _verify_login(cls, token: str) -> bool: + keycloak = ServiceProviderABC.get_global_service(KeycloakClient) + try: + token_info = keycloak.introspect(token) + return token_info.get("active", False) + except KeycloakAuthenticationError as e: + _logger.debug(f"Keycloak authentication error: {e}") + return False + except Exception as e: + _logger.error(f"Unexpected error during token verification: {e}") + return False diff --git a/src/cpl-api/cpl/api/middleware/logging.py b/src/cpl-api/cpl/api/middleware/logging.py index 96919d0d..6a9f888e 100644 --- a/src/cpl-api/cpl/api/middleware/logging.py +++ b/src/cpl-api/cpl/api/middleware/logging.py @@ -1,21 +1,41 @@ import time -from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send from cpl.api.api_logger import APILogger +from cpl.api.middleware.request import get_request _logger = APILogger(__name__) +class LoggingMiddleware: + def __init__(self, app: ASGIApp): + self.app = app -class LoggingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = get_request() await self._log_request(request) - response = await call_next(request) - await self._log_after_request(request, response) + start_time = time.time() - return response + response_body = b"" + status_code = 500 + + async def send_wrapper(message): + nonlocal response_body, status_code + if message["type"] == "http.response.start": + status_code = message["status"] + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + await send(message) + + await self.app(scope, receive, send_wrapper) + + duration = (time.time() - start_time) * 1000 + await self._log_after_request(request, status_code, duration) @staticmethod def _filter_relevant_headers(headers: dict) -> dict: @@ -33,7 +53,7 @@ class LoggingMiddleware(BaseHTTPMiddleware): @classmethod async def _log_request(cls, request: Request): _logger.debug( - f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}" + f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}" ) from cpl.core.ctx.user_context import get_user @@ -55,11 +75,10 @@ class LoggingMiddleware(BaseHTTPMiddleware): ), } - _logger.trace(f"Request {request.state.request_id}: {request_info}") + _logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}") @staticmethod - async def _log_after_request(request: Request, response: Response): - duration = (time.time() - request.state.start_time) * 1000 + async def _log_after_request(request: Request, status_code: int, duration: float): _logger.info( - f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms" - ) + f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms" + ) \ No newline at end of file diff --git a/src/cpl-api/cpl/api/middleware/request.py b/src/cpl-api/cpl/api/middleware/request.py index 16c48ea7..7ec8dc61 100644 --- a/src/cpl-api/cpl/api/middleware/request.py +++ b/src/cpl-api/cpl/api/middleware/request.py @@ -3,7 +3,7 @@ from contextvars import ContextVar from typing import Optional, Union from uuid import uuid4 -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request from starlette.websockets import WebSocket from cpl.api.api_logger import APILogger @@ -14,35 +14,38 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa _logger = APILogger(__name__) -class RequestMiddleware(BaseHTTPMiddleware): - _request_token = {} - _user_token = {} +class RequestMiddleware: + def __init__(self, app): + self._app = app + self._ctx_token = None - @classmethod - async def set_request_data(cls, request: TRequest): + async def __call__(self, scope, receive, send): + request = Request(scope, receive, send) + await self.set_request_data(request) + + try: + await self._app(scope, receive, send) + finally: + await self.clean_request_data() + + async def set_request_data(self, request: TRequest): request.state.request_id = uuid4() request.state.start_time = time.time() _logger.trace(f"Set new current request: {request.state.request_id}") - cls._request_token[request.state.request_id] = _request_context.set(request) + self._ctx_token = _request_context.set(request) - @classmethod - async def clean_request_data(cls): + async def clean_request_data(self): request = get_request() if request is None: return - if request.state.request_id in cls._request_token: - _request_context.reset(cls._request_token[request.state.request_id]) + if self._ctx_token is None: + return - async def dispatch(self, request: TRequest, call_next): - await self.set_request_data(request) - try: - response = await call_next(request) - return response - finally: - await self.clean_request_data() + _logger.trace(f"Clearing current request: {request.state.request_id}") + _request_context.reset(self._ctx_token) def get_request() -> Optional[Union[TRequest, WebSocket]]: - return _request_context.get() + return _request_context.get() \ No newline at end of file