WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
3 changed files with 75 additions and 50 deletions
Showing only changes of commit ea3055527c - Show all commits

View File

@@ -1,9 +1,8 @@
from keycloak import KeycloakAuthenticationError from keycloak import KeycloakAuthenticationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from cpl.api.api_logger import APILogger from cpl.api.api_logger import APILogger
from cpl.api.error import Unauthorized from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient from cpl.auth.keycloak import KeycloakClient
from cpl.dependency import ServiceProviderABC from cpl.dependency import ServiceProviderABC
@@ -11,27 +10,18 @@ from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__) _logger = APILogger(__name__)
class AuthenticationMiddleware(BaseHTTPMiddleware): class AuthenticationMiddleware:
@classmethod def __init__(self, app):
async def _verify_login(cls, token: str) -> bool: self._app = app
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
try:
token_info = keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
_logger.error(f"Unexpected error during token verification: {e}")
return False
async def dispatch(self, request: Request, call_next): async def __call__(self, scope, receive, send):
request = get_request()
url = request.url.path url = request.url.path
if url not in Router.get_auth_required_routes(): if url not in Router.get_auth_required_routes():
_logger.trace(f"No authentication required for {url}") _logger.trace(f"No authentication required for {url}")
return await call_next(request) return await self._app(scope, receive, send)
if not request.headers.get("Authorization"): if not request.headers.get("Authorization"):
_logger.debug(f"Unauthorized access to {url}, missing Authorization header") _logger.debug(f"Unauthorized access to {url}, missing Authorization header")
@@ -49,4 +39,17 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
# check user exists in db, if not create # check user exists in db, if not create
# unauthorized if user is deleted # unauthorized if user is deleted
return await call_next(request) return await self._app(scope, receive, send)
@classmethod
async def _verify_login(cls, token: str) -> bool:
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
try:
token_info = keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
_logger.error(f"Unexpected error during token verification: {e}")
return False

View File

@@ -1,21 +1,41 @@
import time import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.types import ASGIApp, Receive, Scope, Send
from cpl.api.api_logger import APILogger from cpl.api.api_logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__) _logger = APILogger(__name__)
class LoggingMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
class LoggingMiddleware(BaseHTTPMiddleware): async def __call__(self, scope: Scope, receive: Receive, send: Send):
async def dispatch(self, request: Request, call_next): if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = get_request()
await self._log_request(request) await self._log_request(request)
response = await call_next(request) start_time = time.time()
await self._log_after_request(request, response)
return response response_body = b""
status_code = 500
async def send_wrapper(message):
nonlocal response_body, status_code
if message["type"] == "http.response.start":
status_code = message["status"]
if message["type"] == "http.response.body":
response_body += message.get("body", b"")
await send(message)
await self.app(scope, receive, send_wrapper)
duration = (time.time() - start_time) * 1000
await self._log_after_request(request, status_code, duration)
@staticmethod @staticmethod
def _filter_relevant_headers(headers: dict) -> dict: def _filter_relevant_headers(headers: dict) -> dict:
@@ -33,7 +53,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
@classmethod @classmethod
async def _log_request(cls, request: Request): async def _log_request(cls, request: Request):
_logger.debug( _logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}" f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
) )
from cpl.core.ctx.user_context import get_user from cpl.core.ctx.user_context import get_user
@@ -55,11 +75,10 @@ class LoggingMiddleware(BaseHTTPMiddleware):
), ),
} }
_logger.trace(f"Request {request.state.request_id}: {request_info}") _logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
@staticmethod @staticmethod
async def _log_after_request(request: Request, response: Response): async def _log_after_request(request: Request, status_code: int, duration: float):
duration = (time.time() - request.state.start_time) * 1000
_logger.info( _logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms" f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
) )

View File

@@ -3,7 +3,7 @@ from contextvars import ContextVar
from typing import Optional, Union from typing import Optional, Union
from uuid import uuid4 from uuid import uuid4
from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from cpl.api.api_logger import APILogger from cpl.api.api_logger import APILogger
@@ -14,34 +14,37 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
_logger = APILogger(__name__) _logger = APILogger(__name__)
class RequestMiddleware(BaseHTTPMiddleware): class RequestMiddleware:
_request_token = {} def __init__(self, app):
_user_token = {} self._app = app
self._ctx_token = None
@classmethod async def __call__(self, scope, receive, send):
async def set_request_data(cls, request: TRequest): request = Request(scope, receive, send)
await self.set_request_data(request)
try:
await self._app(scope, receive, send)
finally:
await self.clean_request_data()
async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4() request.state.request_id = uuid4()
request.state.start_time = time.time() request.state.start_time = time.time()
_logger.trace(f"Set new current request: {request.state.request_id}") _logger.trace(f"Set new current request: {request.state.request_id}")
cls._request_token[request.state.request_id] = _request_context.set(request) self._ctx_token = _request_context.set(request)
@classmethod async def clean_request_data(self):
async def clean_request_data(cls):
request = get_request() request = get_request()
if request is None: if request is None:
return return
if request.state.request_id in cls._request_token: if self._ctx_token is None:
_request_context.reset(cls._request_token[request.state.request_id]) return
async def dispatch(self, request: TRequest, call_next): _logger.trace(f"Clearing current request: {request.state.request_id}")
await self.set_request_data(request) _request_context.reset(self._ctx_token)
try:
response = await call_next(request)
return response
finally:
await self.clean_request_data()
def get_request() -> Optional[Union[TRequest, WebSocket]]: def get_request() -> Optional[Union[TRequest, WebSocket]]: