WIP: dev into master #184

Draft
edraft wants to merge 121 commits from dev into master
3 changed files with 75 additions and 50 deletions
Showing only changes of commit ea3055527c - Show all commits

View File

@@ -1,9 +1,8 @@
from keycloak import KeycloakAuthenticationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from cpl.api.api_logger import APILogger
from cpl.api.error import Unauthorized
from cpl.api.middleware.request import get_request
from cpl.api.router import Router
from cpl.auth.keycloak import KeycloakClient
from cpl.dependency import ServiceProviderABC
@@ -11,27 +10,18 @@ from cpl.dependency import ServiceProviderABC
_logger = APILogger(__name__)
class AuthenticationMiddleware(BaseHTTPMiddleware):
class AuthenticationMiddleware:
@classmethod
async def _verify_login(cls, token: str) -> bool:
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
try:
token_info = keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
_logger.error(f"Unexpected error during token verification: {e}")
return False
def __init__(self, app):
self._app = app
async def dispatch(self, request: Request, call_next):
async def __call__(self, scope, receive, send):
request = get_request()
url = request.url.path
if url not in Router.get_auth_required_routes():
_logger.trace(f"No authentication required for {url}")
return await call_next(request)
return await self._app(scope, receive, send)
if not request.headers.get("Authorization"):
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
@@ -49,4 +39,17 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
# check user exists in db, if not create
# unauthorized if user is deleted
return await call_next(request)
return await self._app(scope, receive, send)
@classmethod
async def _verify_login(cls, token: str) -> bool:
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
try:
token_info = keycloak.introspect(token)
return token_info.get("active", False)
except KeycloakAuthenticationError as e:
_logger.debug(f"Keycloak authentication error: {e}")
return False
except Exception as e:
_logger.error(f"Unexpected error during token verification: {e}")
return False

View File

@@ -1,21 +1,41 @@
import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send
from cpl.api.api_logger import APILogger
from cpl.api.middleware.request import get_request
_logger = APILogger(__name__)
class LoggingMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = get_request()
await self._log_request(request)
response = await call_next(request)
await self._log_after_request(request, response)
start_time = time.time()
return response
response_body = b""
status_code = 500
async def send_wrapper(message):
nonlocal response_body, status_code
if message["type"] == "http.response.start":
status_code = message["status"]
if message["type"] == "http.response.body":
response_body += message.get("body", b"")
await send(message)
await self.app(scope, receive, send_wrapper)
duration = (time.time() - start_time) * 1000
await self._log_after_request(request, status_code, duration)
@staticmethod
def _filter_relevant_headers(headers: dict) -> dict:
@@ -33,7 +53,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
@classmethod
async def _log_request(cls, request: Request):
_logger.debug(
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
)
from cpl.core.ctx.user_context import get_user
@@ -55,11 +75,10 @@ class LoggingMiddleware(BaseHTTPMiddleware):
),
}
_logger.trace(f"Request {request.state.request_id}: {request_info}")
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
@staticmethod
async def _log_after_request(request: Request, response: Response):
duration = (time.time() - request.state.start_time) * 1000
async def _log_after_request(request: Request, status_code: int, duration: float):
_logger.info(
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
)

View File

@@ -3,7 +3,7 @@ from contextvars import ContextVar
from typing import Optional, Union
from uuid import uuid4
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.websockets import WebSocket
from cpl.api.api_logger import APILogger
@@ -14,35 +14,38 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
_logger = APILogger(__name__)
class RequestMiddleware(BaseHTTPMiddleware):
_request_token = {}
_user_token = {}
class RequestMiddleware:
def __init__(self, app):
self._app = app
self._ctx_token = None
@classmethod
async def set_request_data(cls, request: TRequest):
async def __call__(self, scope, receive, send):
request = Request(scope, receive, send)
await self.set_request_data(request)
try:
await self._app(scope, receive, send)
finally:
await self.clean_request_data()
async def set_request_data(self, request: TRequest):
request.state.request_id = uuid4()
request.state.start_time = time.time()
_logger.trace(f"Set new current request: {request.state.request_id}")
cls._request_token[request.state.request_id] = _request_context.set(request)
self._ctx_token = _request_context.set(request)
@classmethod
async def clean_request_data(cls):
async def clean_request_data(self):
request = get_request()
if request is None:
return
if request.state.request_id in cls._request_token:
_request_context.reset(cls._request_token[request.state.request_id])
if self._ctx_token is None:
return
async def dispatch(self, request: TRequest, call_next):
await self.set_request_data(request)
try:
response = await call_next(request)
return response
finally:
await self.clean_request_data()
_logger.trace(f"Clearing current request: {request.state.request_id}")
_request_context.reset(self._ctx_token)
def get_request() -> Optional[Union[TRequest, WebSocket]]:
return _request_context.get()
return _request_context.get()