Changed middleware to asgi
This commit is contained in:
@@ -1,9 +1,8 @@
|
|||||||
from keycloak import KeycloakAuthenticationError
|
from keycloak import KeycloakAuthenticationError
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.api_logger import APILogger
|
||||||
from cpl.api.error import Unauthorized
|
from cpl.api.error import Unauthorized
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
from cpl.api.router import Router
|
from cpl.api.router import Router
|
||||||
from cpl.auth.keycloak import KeycloakClient
|
from cpl.auth.keycloak import KeycloakClient
|
||||||
from cpl.dependency import ServiceProviderABC
|
from cpl.dependency import ServiceProviderABC
|
||||||
@@ -11,27 +10,18 @@ from cpl.dependency import ServiceProviderABC
|
|||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
class AuthenticationMiddleware:
|
||||||
|
|
||||||
@classmethod
|
def __init__(self, app):
|
||||||
async def _verify_login(cls, token: str) -> bool:
|
self._app = app
|
||||||
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
|
|
||||||
try:
|
|
||||||
token_info = keycloak.introspect(token)
|
|
||||||
return token_info.get("active", False)
|
|
||||||
except KeycloakAuthenticationError as e:
|
|
||||||
_logger.debug(f"Keycloak authentication error: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
_logger.error(f"Unexpected error during token verification: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def __call__(self, scope, receive, send):
|
||||||
|
request = get_request()
|
||||||
url = request.url.path
|
url = request.url.path
|
||||||
|
|
||||||
if url not in Router.get_auth_required_routes():
|
if url not in Router.get_auth_required_routes():
|
||||||
_logger.trace(f"No authentication required for {url}")
|
_logger.trace(f"No authentication required for {url}")
|
||||||
return await call_next(request)
|
return await self._app(scope, receive, send)
|
||||||
|
|
||||||
if not request.headers.get("Authorization"):
|
if not request.headers.get("Authorization"):
|
||||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||||
@@ -49,4 +39,17 @@ class AuthenticationMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# check user exists in db, if not create
|
# check user exists in db, if not create
|
||||||
# unauthorized if user is deleted
|
# unauthorized if user is deleted
|
||||||
return await call_next(request)
|
return await self._app(scope, receive, send)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _verify_login(cls, token: str) -> bool:
|
||||||
|
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
|
||||||
|
try:
|
||||||
|
token_info = keycloak.introspect(token)
|
||||||
|
return token_info.get("active", False)
|
||||||
|
except KeycloakAuthenticationError as e:
|
||||||
|
_logger.debug(f"Keycloak authentication error: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
_logger.error(f"Unexpected error during token verification: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@@ -1,21 +1,41 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.api_logger import APILogger
|
||||||
|
from cpl.api.middleware.request import get_request
|
||||||
|
|
||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
|
class LoggingMiddleware:
|
||||||
|
def __init__(self, app: ASGIApp):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
async def dispatch(self, request: Request, call_next):
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
request = get_request()
|
||||||
await self._log_request(request)
|
await self._log_request(request)
|
||||||
response = await call_next(request)
|
start_time = time.time()
|
||||||
await self._log_after_request(request, response)
|
|
||||||
|
|
||||||
return response
|
response_body = b""
|
||||||
|
status_code = 500
|
||||||
|
|
||||||
|
async def send_wrapper(message):
|
||||||
|
nonlocal response_body, status_code
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
status_code = message["status"]
|
||||||
|
if message["type"] == "http.response.body":
|
||||||
|
response_body += message.get("body", b"")
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
await self.app(scope, receive, send_wrapper)
|
||||||
|
|
||||||
|
duration = (time.time() - start_time) * 1000
|
||||||
|
await self._log_after_request(request, status_code, duration)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _filter_relevant_headers(headers: dict) -> dict:
|
def _filter_relevant_headers(headers: dict) -> dict:
|
||||||
@@ -33,7 +53,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def _log_request(cls, request: Request):
|
async def _log_request(cls, request: Request):
|
||||||
_logger.debug(
|
_logger.debug(
|
||||||
f"Request {request.state.request_id}: {request.method}@{request.url.path} from {request.client.host}"
|
f"Request {getattr(request.state, 'request_id', '-')}: {request.method}@{request.url.path} from {request.client.host}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from cpl.core.ctx.user_context import get_user
|
from cpl.core.ctx.user_context import get_user
|
||||||
@@ -55,11 +75,10 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
_logger.trace(f"Request {request.state.request_id}: {request_info}")
|
_logger.trace(f"Request {getattr(request.state, 'request_id', '-')}: {request_info}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _log_after_request(request: Request, response: Response):
|
async def _log_after_request(request: Request, status_code: int, duration: float):
|
||||||
duration = (time.time() - request.state.start_time) * 1000
|
|
||||||
_logger.info(
|
_logger.info(
|
||||||
f"Request finished {request.state.request_id}: {response.status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
f"Request finished {getattr(request.state, 'request_id', '-')}: {status_code}-{request.method}@{request.url.path} from {request.client.host} in {duration:.2f}ms"
|
||||||
)
|
)
|
||||||
@@ -3,7 +3,7 @@ from contextvars import ContextVar
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.requests import Request
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
from cpl.api.api_logger import APILogger
|
from cpl.api.api_logger import APILogger
|
||||||
@@ -14,35 +14,38 @@ _request_context: ContextVar[Union[TRequest, None]] = ContextVar("request", defa
|
|||||||
_logger = APILogger(__name__)
|
_logger = APILogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RequestMiddleware(BaseHTTPMiddleware):
|
class RequestMiddleware:
|
||||||
_request_token = {}
|
def __init__(self, app):
|
||||||
_user_token = {}
|
self._app = app
|
||||||
|
self._ctx_token = None
|
||||||
|
|
||||||
@classmethod
|
async def __call__(self, scope, receive, send):
|
||||||
async def set_request_data(cls, request: TRequest):
|
request = Request(scope, receive, send)
|
||||||
|
await self.set_request_data(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._app(scope, receive, send)
|
||||||
|
finally:
|
||||||
|
await self.clean_request_data()
|
||||||
|
|
||||||
|
async def set_request_data(self, request: TRequest):
|
||||||
request.state.request_id = uuid4()
|
request.state.request_id = uuid4()
|
||||||
request.state.start_time = time.time()
|
request.state.start_time = time.time()
|
||||||
_logger.trace(f"Set new current request: {request.state.request_id}")
|
_logger.trace(f"Set new current request: {request.state.request_id}")
|
||||||
|
|
||||||
cls._request_token[request.state.request_id] = _request_context.set(request)
|
self._ctx_token = _request_context.set(request)
|
||||||
|
|
||||||
@classmethod
|
async def clean_request_data(self):
|
||||||
async def clean_request_data(cls):
|
|
||||||
request = get_request()
|
request = get_request()
|
||||||
if request is None:
|
if request is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if request.state.request_id in cls._request_token:
|
if self._ctx_token is None:
|
||||||
_request_context.reset(cls._request_token[request.state.request_id])
|
return
|
||||||
|
|
||||||
async def dispatch(self, request: TRequest, call_next):
|
_logger.trace(f"Clearing current request: {request.state.request_id}")
|
||||||
await self.set_request_data(request)
|
_request_context.reset(self._ctx_token)
|
||||||
try:
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
finally:
|
|
||||||
await self.clean_request_data()
|
|
||||||
|
|
||||||
|
|
||||||
def get_request() -> Optional[Union[TRequest, WebSocket]]:
|
def get_request() -> Optional[Union[TRequest, WebSocket]]:
|
||||||
return _request_context.get()
|
return _request_context.get()
|
||||||
Reference in New Issue
Block a user