diff --git a/src/cpl-api/cpl/api/error.py b/src/cpl-api/cpl/api/error.py index 89d5373c..5555d006 100644 --- a/src/cpl-api/cpl/api/error.py +++ b/src/cpl-api/cpl/api/error.py @@ -1,9 +1,15 @@ from http.client import HTTPException +from starlette.responses import JSONResponse + class APIError(HTTPException): status_code = 500 + @classmethod + def response(cls): + return JSONResponse({"error": cls.__name__}, status_code=cls.status_code) + class Unauthorized(APIError): status_code = 401 diff --git a/src/cpl-api/cpl/api/middleware/authentication.py b/src/cpl-api/cpl/api/middleware/authentication.py new file mode 100644 index 00000000..d82ed747 --- /dev/null +++ b/src/cpl-api/cpl/api/middleware/authentication.py @@ -0,0 +1,49 @@ +from keycloak import KeycloakAuthenticationError +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +from cpl.api.api_logger import APILogger +from cpl.api.error import Unauthorized +from cpl.api.router import Router +from cpl.auth.keycloak import KeycloakClient +from cpl.dependency import ServiceProviderABC + +_logger = APILogger(__name__) + + +class AuthenticationMiddleware(BaseHTTPMiddleware): + + @classmethod + async def _verify_login(cls, token: str) -> bool: + keycloak = ServiceProviderABC.get_global_service(KeycloakClient) + try: + user_info = keycloak.userinfo(token) + if not user_info: + return False + except KeycloakAuthenticationError: + return False + return True + + async def dispatch(self, request: Request, call_next): + url = request.url.path + + if url not in Router.get_auth_required_routes(): + _logger.trace(f"No authentication required for {url}") + return await call_next(request) + + if not request.headers.get("Authorization"): + _logger.debug(f"Unauthorized access to {url}, missing Authorization header") + return Unauthorized(f"Missing header Authorization").response() + + + auth_header = request.headers.get("Authorization", None) + if not auth_header or not auth_header.startswith("Bearer "): + return Unauthorized("Invalid Authorization header").response() + + if not await self._verify_login(auth_header.split("Bearer ")[1]): + _logger.debug(f"Unauthorized access to {url}, invalid token") + return Unauthorized("Invalid token").response() + + # check user exists in db, if not create + # unauthorized if user is deleted + return await call_next(request) diff --git a/src/cpl-api/cpl/api/router.py b/src/cpl-api/cpl/api/router.py index 0a8d8ba0..e8936b25 100644 --- a/src/cpl-api/cpl/api/router.py +++ b/src/cpl-api/cpl/api/router.py @@ -3,11 +3,35 @@ from starlette.routing import Route class Router: _registered_routes: list[Route] = [] + _auth_required: list[str] = [] @classmethod def get_routes(cls) -> list[Route]: return cls._registered_routes + @classmethod + def get_auth_required_routes(cls) -> list[str]: + return cls._auth_required + + @classmethod + def authenticate(cls): + """ + Decorator to mark a route as requiring authentication. + Usage: + @Route.authenticate() + @Route.get("/example") + async def example_endpoint(request: TRequest): + ... + """ + + def inner(fn): + route_path = getattr(fn, "_route_path", None) + if route_path and route_path not in cls._auth_required: + cls._auth_required.append(route_path) + return fn + + return inner + @classmethod def route(cls, path=None, **kwargs): def inner(fn): @@ -57,4 +81,4 @@ class Router: return fn - return inner \ No newline at end of file + return inner diff --git a/src/cpl-api/cpl/api/typing.py b/src/cpl-api/cpl/api/typing.py index 06eea50e..ca570e59 100644 --- a/src/cpl-api/cpl/api/typing.py +++ b/src/cpl-api/cpl/api/typing.py @@ -10,4 +10,4 @@ HTTPMethods = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] PartialMiddleware = Union[ Middleware, Callable[[ASGIApp], ASGIApp], -] \ No newline at end of file +] diff --git a/src/cpl-api/cpl/api/web_app.py b/src/cpl-api/cpl/api/web_app.py index ff4438cb..6c668782 100644 --- a/src/cpl-api/cpl/api/web_app.py +++ b/src/cpl-api/cpl/api/web_app.py @@ -13,6 +13,7 @@ from starlette.types import ExceptionHandler from cpl.api.api_logger import APILogger from cpl.api.api_settings import ApiSettings from cpl.api.error import APIError +from cpl.api.middleware.authentication import AuthenticationMiddleware from cpl.api.middleware.logging import LoggingMiddleware from cpl.api.middleware.request import RequestMiddleware from cpl.api.router import Router @@ -24,7 +25,6 @@ from cpl.dependency.service_provider_abc import ServiceProviderABC _logger = APILogger("API") - class WebApp(ApplicationABC): def __init__(self, services: ServiceProviderABC): super().__init__(services) @@ -37,18 +37,22 @@ class WebApp(ApplicationABC): Middleware(RequestMiddleware), Middleware(LoggingMiddleware), ] - self._exception_handlers: Mapping[Any, ExceptionHandler] = {Exception: self.handle_exception} + self._exception_handlers: Mapping[Any, ExceptionHandler] = { + Exception: self._handle_exception, + APIError: self._handle_exception, + } @staticmethod - async def handle_exception(request: Request, exc: Exception): + async def _handle_exception(request: Request, exc: Exception): + if isinstance(exc, APIError): + _logger.error(exc) + return JSONResponse({"error": str(exc)}, status_code=exc.status_code) + if hasattr(request.state, "request_id"): _logger.error(f"Request {request.state.request_id}", exc) else: _logger.error("Request unknown", exc) - if isinstance(exc, APIError): - return JSONResponse({"error": str(exc)}, status_code=exc.status_code) - return JSONResponse({"error": str(exc)}, status_code=500) def _get_allowed_origins(self): @@ -96,7 +100,15 @@ class WebApp(ApplicationABC): self._check_for_app() assert path is not None, "path must not be None" assert fn is not None, "fn must not be None" - assert method in ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], "method must be a valid HTTP method" + assert method in [ + "GET", + "POST", + "PUT", + "DELETE", + "PATCH", + "OPTIONS", + "HEAD", + ], "method must be a valid HTTP method" self._routes.append(Route(path, fn, methods=[method], **kwargs)) return self @@ -105,15 +117,20 @@ class WebApp(ApplicationABC): if isinstance(middleware, Middleware): self._middleware.append(middleware) - elif callable(middleware): self._middleware.append(Middleware(middleware)) else: raise ValueError("middleware must be of type starlette.middleware.Middleware or a callable") - return self + def with_authentication(self): + self.with_middleware(AuthenticationMiddleware) + return self + + def with_authorization(self): + pass + def main(self): _logger.debug(f"Preparing API") if self._app is None: diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index 075c14c5..02f455b8 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -77,7 +77,7 @@ class ServiceProvider(ServiceProviderABC): return implementations - def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]: + def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]: params = [] for param in sig.parameters.items(): parameter = param[1] diff --git a/src/cpl-dependency/cpl/dependency/service_provider_abc.py b/src/cpl-dependency/cpl/dependency/service_provider_abc.py index 5873f443..6e0c5dda 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider_abc.py +++ b/src/cpl-dependency/cpl/dependency/service_provider_abc.py @@ -36,7 +36,7 @@ class ServiceProviderABC(ABC): return cls._provider.get_services(instance_type, *args, **kwargs) @abstractmethod - def _build_by_signature(self, sig: Signature, origin_service_type: type=None) -> list[R]: ... + def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[R]: ... @abstractmethod def _build_service(self, service_type: type, *args, **kwargs) -> object: diff --git a/tests/custom/api/src/main.py b/tests/custom/api/src/main.py index 54cb83c9..e0064b47 100644 --- a/tests/custom/api/src/main.py +++ b/tests/custom/api/src/main.py @@ -15,6 +15,7 @@ def main(): app.with_route(path="/route1", fn=lambda r: JSONResponse("route1"), method="GET") app.with_routes_directory("routes") app.with_logging() + app.with_authentication() app.run() diff --git a/tests/custom/api/src/routes/ping.py b/tests/custom/api/src/routes/ping.py index 68a79d1a..0324a2f6 100644 --- a/tests/custom/api/src/routes/ping.py +++ b/tests/custom/api/src/routes/ping.py @@ -7,6 +7,7 @@ from cpl.core.log import Logger from service import PingService +@Router.authenticate() @Router.get(f"/ping") async def ping(r: Request, ping: PingService, logger: Logger): logger.info(f"Ping: {ping}")