[WIP] Authentication
All checks were successful
Build on push / prepare (push) Successful in 9s
Build on push / core (push) Successful in 19s
Build on push / query (push) Successful in 19s
Build on push / dependency (push) Successful in 17s
Build on push / application (push) Successful in 16s
Build on push / mail (push) Successful in 16s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 22s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 13s
All checks were successful
Build on push / prepare (push) Successful in 9s
Build on push / core (push) Successful in 19s
Build on push / query (push) Successful in 19s
Build on push / dependency (push) Successful in 17s
Build on push / application (push) Successful in 16s
Build on push / mail (push) Successful in 16s
Build on push / translation (push) Successful in 18s
Build on push / database (push) Successful in 22s
Build on push / auth (push) Successful in 15s
Build on push / api (push) Successful in 13s
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
from http.client import HTTPException
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class APIError(HTTPException):
|
||||
status_code = 500
|
||||
|
||||
@classmethod
|
||||
def response(cls):
|
||||
return JSONResponse({"error": cls.__name__}, status_code=cls.status_code)
|
||||
|
||||
|
||||
class Unauthorized(APIError):
|
||||
status_code = 401
|
||||
|
||||
49
src/cpl-api/cpl/api/middleware/authentication.py
Normal file
49
src/cpl-api/cpl/api/middleware/authentication.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from keycloak import KeycloakAuthenticationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.error import Unauthorized
|
||||
from cpl.api.router import Router
|
||||
from cpl.auth.keycloak import KeycloakClient
|
||||
from cpl.dependency import ServiceProviderABC
|
||||
|
||||
_logger = APILogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
@classmethod
|
||||
async def _verify_login(cls, token: str) -> bool:
|
||||
keycloak = ServiceProviderABC.get_global_service(KeycloakClient)
|
||||
try:
|
||||
user_info = keycloak.userinfo(token)
|
||||
if not user_info:
|
||||
return False
|
||||
except KeycloakAuthenticationError:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
url = request.url.path
|
||||
|
||||
if url not in Router.get_auth_required_routes():
|
||||
_logger.trace(f"No authentication required for {url}")
|
||||
return await call_next(request)
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
_logger.debug(f"Unauthorized access to {url}, missing Authorization header")
|
||||
return Unauthorized(f"Missing header Authorization").response()
|
||||
|
||||
|
||||
auth_header = request.headers.get("Authorization", None)
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return Unauthorized("Invalid Authorization header").response()
|
||||
|
||||
if not await self._verify_login(auth_header.split("Bearer ")[1]):
|
||||
_logger.debug(f"Unauthorized access to {url}, invalid token")
|
||||
return Unauthorized("Invalid token").response()
|
||||
|
||||
# check user exists in db, if not create
|
||||
# unauthorized if user is deleted
|
||||
return await call_next(request)
|
||||
@@ -3,11 +3,35 @@ from starlette.routing import Route
|
||||
|
||||
class Router:
|
||||
_registered_routes: list[Route] = []
|
||||
_auth_required: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def get_routes(cls) -> list[Route]:
|
||||
return cls._registered_routes
|
||||
|
||||
@classmethod
|
||||
def get_auth_required_routes(cls) -> list[str]:
|
||||
return cls._auth_required
|
||||
|
||||
@classmethod
|
||||
def authenticate(cls):
|
||||
"""
|
||||
Decorator to mark a route as requiring authentication.
|
||||
Usage:
|
||||
@Route.authenticate()
|
||||
@Route.get("/example")
|
||||
async def example_endpoint(request: TRequest):
|
||||
...
|
||||
"""
|
||||
|
||||
def inner(fn):
|
||||
route_path = getattr(fn, "_route_path", None)
|
||||
if route_path and route_path not in cls._auth_required:
|
||||
cls._auth_required.append(route_path)
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def route(cls, path=None, **kwargs):
|
||||
def inner(fn):
|
||||
|
||||
@@ -13,6 +13,7 @@ from starlette.types import ExceptionHandler
|
||||
from cpl.api.api_logger import APILogger
|
||||
from cpl.api.api_settings import ApiSettings
|
||||
from cpl.api.error import APIError
|
||||
from cpl.api.middleware.authentication import AuthenticationMiddleware
|
||||
from cpl.api.middleware.logging import LoggingMiddleware
|
||||
from cpl.api.middleware.request import RequestMiddleware
|
||||
from cpl.api.router import Router
|
||||
@@ -24,7 +25,6 @@ from cpl.dependency.service_provider_abc import ServiceProviderABC
|
||||
_logger = APILogger("API")
|
||||
|
||||
|
||||
|
||||
class WebApp(ApplicationABC):
|
||||
def __init__(self, services: ServiceProviderABC):
|
||||
super().__init__(services)
|
||||
@@ -37,18 +37,22 @@ class WebApp(ApplicationABC):
|
||||
Middleware(RequestMiddleware),
|
||||
Middleware(LoggingMiddleware),
|
||||
]
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {Exception: self.handle_exception}
|
||||
self._exception_handlers: Mapping[Any, ExceptionHandler] = {
|
||||
Exception: self._handle_exception,
|
||||
APIError: self._handle_exception,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_exception(request: Request, exc: Exception):
|
||||
async def _handle_exception(request: Request, exc: Exception):
|
||||
if isinstance(exc, APIError):
|
||||
_logger.error(exc)
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
if hasattr(request.state, "request_id"):
|
||||
_logger.error(f"Request {request.state.request_id}", exc)
|
||||
else:
|
||||
_logger.error("Request unknown", exc)
|
||||
|
||||
if isinstance(exc, APIError):
|
||||
return JSONResponse({"error": str(exc)}, status_code=exc.status_code)
|
||||
|
||||
return JSONResponse({"error": str(exc)}, status_code=500)
|
||||
|
||||
def _get_allowed_origins(self):
|
||||
@@ -96,7 +100,15 @@ class WebApp(ApplicationABC):
|
||||
self._check_for_app()
|
||||
assert path is not None, "path must not be None"
|
||||
assert fn is not None, "fn must not be None"
|
||||
assert method in ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], "method must be a valid HTTP method"
|
||||
assert method in [
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"PATCH",
|
||||
"OPTIONS",
|
||||
"HEAD",
|
||||
], "method must be a valid HTTP method"
|
||||
self._routes.append(Route(path, fn, methods=[method], **kwargs))
|
||||
return self
|
||||
|
||||
@@ -105,15 +117,20 @@ class WebApp(ApplicationABC):
|
||||
|
||||
if isinstance(middleware, Middleware):
|
||||
self._middleware.append(middleware)
|
||||
|
||||
elif callable(middleware):
|
||||
self._middleware.append(Middleware(middleware))
|
||||
else:
|
||||
raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable")
|
||||
|
||||
|
||||
return self
|
||||
|
||||
def with_authentication(self):
|
||||
self.with_middleware(AuthenticationMiddleware)
|
||||
return self
|
||||
|
||||
def with_authorization(self):
|
||||
pass
|
||||
|
||||
def main(self):
|
||||
_logger.debug(f"Preparing API")
|
||||
if self._app is None:
|
||||
|
||||
@@ -77,7 +77,7 @@ class ServiceProvider(ServiceProviderABC):
|
||||
|
||||
return implementations
|
||||
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]:
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]:
|
||||
params = []
|
||||
for param in sig.parameters.items():
|
||||
parameter = param[1]
|
||||
|
||||
@@ -36,7 +36,7 @@ class ServiceProviderABC(ABC):
|
||||
return cls._provider.get_services(instance_type, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]: ...
|
||||
def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _build_service(self, service_type: type, *args, **kwargs) -> object:
|
||||
|
||||
@@ -15,6 +15,7 @@ def main():
|
||||
app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET")
|
||||
app.with_routes_directory("routes")
|
||||
app.with_logging()
|
||||
app.with_authentication()
|
||||
|
||||
app.run()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from cpl.core.log import Logger
|
||||
from service import PingService
|
||||
|
||||
|
||||
@Router.authenticate()
|
||||
@Router.get(f"/ping")
|
||||
async def ping(r: Request, ping: PingService, logger: Logger):
|
||||
logger.info(f"Ping: {ping}")
|
||||
|
||||
Reference in New Issue
Block a user