WIP: dev into master #184
@@ -1,9 +1,8 @@
|
||||
from keycloak import KeycloakAuthenticationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.middleware.request import get_request
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
@@ -11,27 +10,18 @@ from cpl.dependency import ServiceProviderABC
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
class AuthenticationMiddleware:
|
||||
|
||||
@classmethod
|
||||
async def _verify_login(cls, token: str) -> bool:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
|
||||
try:
|
||||
token_info = keycloak.introspect(token)
|
||||
return token_info.get("active", False)
|
||||
except KeycloakAuthenticationError as e:
|
||||
_logger.debug(f"Keycloak authentication error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
_logger.error(f"Unexpected error during token verification: {e}")
|
||||
return False
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
async def __call__(self, scope, receive, send):
|
||||
request = get_request()
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
_logger.trace(f"No authentication required for {url}")
|
||||
return await call_next(request)
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
@@ -49,4 +39,17 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# check user exists in db, if not create
|
||||
# unauthorized if user is deleted
|
||||
return await call_next(request)
|
||||
return await self._app(scope, receive, send)
|
||||
|
||||
@classmethod
|
||||
async def _verify_login(cls, token: str) -> bool:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
|
||||
try:
|
||||
token_info = keycloak.introspect(token)
|
||||
return token_info.get("active", False)
|
||||
except KeycloakAuthenticationError as e:
|
||||
_logger.debug(f"Keycloak authentication error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
_logger.error(f"Unexpected error during token verification: {e}")
|
||||
return False
|
||||
|
||||
@@ -1,21 +1,41 @@
|
||||
import time
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.middleware.request import get_request
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
class LoggingMiddleware:
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = get_request()
|
||||
await self._log_request(request)
|
||||
response = await call_next(request)
|
||||
await self._log_after_request(request, response)
|
||||
start_time = time.time()
|
||||
|
||||
return response
|
||||
response_body = b""
|
||||
status_code = 500
|
||||
|
||||
async def send_wrapper(message):
|
||||
nonlocal response_body, status_code
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
if message["type"] == "http.response.body":
|
||||
response_body += message.get("body", b"")
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
|
||||
duration = (time.time() - start_time) * 1000
|
||||
await self._log_after_request(request, status_code, duration)
|
||||
|
||||
@staticmethod
|
||||
def _filter_relevant_headers(headers: dict) -> dict:
|
||||
@@ -33,7 +53,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
@classmethod
|
||||
async def _log_request(cls, request: Request):
|
||||
_logger.debug(
|
||||
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||
)
|
||||
|
||||
from cpl.core.ctx.user_context import get_user
|
||||
@@ -55,11 +75,10 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
),
|
||||
}
|
||||
|
||||
_logger.trace(f"Request {request.state.request_id}: {request_info}")
|
||||
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||
|
||||
@staticmethod
|
||||
async def _log_after_request(request: Request, response: Response):
|
||||
duration = (time.time() - request.state.start_time) * 1000
|
||||
async def _log_after_request(request: Request, status_code: int, duration: float):
|
||||
_logger.info(
|
||||
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||
)
|
||||
@@ -3,7 +3,7 @@ from contextvars import ContextVar
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from cpl.api.api_logger import APILogger
|
||||
@@ -14,35 +14,38 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class RequestMiddleware(BaseHTTPMiddleware):
|
||||
_request_token = {}
|
||||
_user_token = {}
|
||||
class RequestMiddleware:
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self._ctx_token = None
|
||||
|
||||
@classmethod
|
||||
async def set_request_data(cls, request: TRequest):
|
||||
async def __call__(self, scope, receive, send):
|
||||
request = Request(scope, receive, send)
|
||||
await self.set_request_data(request)
|
||||
|
||||
try:
|
||||
await self._app(scope, receive, send)
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
|
||||
async def set_request_data(self, request: TRequest):
|
||||
request.state.request_id = uuid4()
|
||||
request.state.start_time = time.time()
|
||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
||||
|
||||
cls._request_token[request.state.request_id] = _request_context.set(request)
|
||||
self._ctx_token = _request_context.set(request)
|
||||
|
||||
@classmethod
|
||||
async def clean_request_data(cls):
|
||||
async def clean_request_data(self):
|
||||
request = get_request()
|
||||
if request is None:
|
||||
return
|
||||
|
||||
if request.state.request_id in cls._request_token:
|
||||
_request_context.reset(cls._request_token[request.state.request_id])
|
||||
if self._ctx_token is None:
|
||||
return
|
||||
|
||||
async def dispatch(self, request: TRequest, call_next):
|
||||
await self.set_request_data(request)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
finally:
|
||||
await self.clean_request_data()
|
||||
_logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||
_request_context.reset(self._ctx_token)
|
||||
|
||||
|
||||
def get_request() -> Optional[Union[TRequest, WebSocket]]:
|
||||
return _request_context.get()
|
||||
return _request_context.get()
|
||||
Reference in New Issue
Block a user