From 8226f74f9ec0fe77c563e54b3e7c2d0800f403b3 Mon Sep 17 00:00:00 2001 From: Major Hayden Date: Wed, 29 Apr 2026 18:20:58 -0500 Subject: [PATCH 1/3] feat: add auth monitoring metrics Signed-off-by: Major Hayden --- src/authentication/api_key_token.py | 33 ++- src/authentication/jwk_token.py | 178 ++++++++++----- src/authentication/k8s.py | 212 ++++++++++++++---- src/authentication/noop.py | 8 + src/authentication/noop_with_token.py | 32 ++- src/authentication/rh_identity.py | 103 +++++++-- src/authentication/utils.py | 46 +++- src/metrics/__init__.py | 33 ++- src/metrics/recording.py | 114 ++++++++++ tests/unit/authentication/test_jwk_token.py | 57 +++++ tests/unit/authentication/test_k8s.py | 90 +++++++- tests/unit/authentication/test_rh_identity.py | 88 +++++++- tests/unit/authentication/test_utils.py | 50 ++++- tests/unit/metrics/test_recording.py | 79 +++++++ 14 files changed, 978 insertions(+), 145 deletions(-) diff --git a/src/authentication/api_key_token.py b/src/authentication/api_key_token.py index 2b9b27900..b173fa254 100644 --- a/src/authentication/api_key_token.py +++ b/src/authentication/api_key_token.py @@ -7,12 +7,18 @@ """ import secrets +import time from fastapi import HTTPException, Request, status from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import ( + AUTH_CAUSE_NO_HEADER, + extract_user_token, + record_auth_metrics, +) from constants import ( + AUTH_MOD_APIKEY_TOKEN, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -59,16 +65,39 @@ async def __call__(self, request: Request) -> AuthTuple: HTTPException: If the bearer token is missing or doesn't match the configured API key (HTTP 401). """ + start_time = time.monotonic() + # try to extract user token from request - user_token = extract_user_token(request.headers) + try: + user_token = extract_user_token(request.headers) + except HTTPException as exc: + # Distinguish missing header from malformed token + reason = "missing_token" + if ( + isinstance(exc.detail, dict) + and exc.detail.get("cause") == AUTH_CAUSE_NO_HEADER + ): + reason = "missing_header" + record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "failure", reason, start_time) + raise + except Exception: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error while extracting API key bearer token") + record_auth_metrics( + AUTH_MOD_APIKEY_TOKEN, "failure", "unexpected_error", start_time + ) + raise # API Key validation. Use secrets.compare_digest for constant-time comparison if not secrets.compare_digest( user_token, self.config.api_key.get_secret_value() ): + record_auth_metrics( + AUTH_MOD_APIKEY_TOKEN, "failure", "invalid_key", start_time + ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key", ) + record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "success", "valid_key", start_time) return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authentication/jwk_token.py b/src/authentication/jwk_token.py index 7cc15870b..db8241b73 100644 --- a/src/authentication/jwk_token.py +++ b/src/authentication/jwk_token.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" import json +import time from asyncio import Lock from collections.abc import Callable from typing import Any @@ -17,12 +18,13 @@ from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_JWK_TOKEN, DEFAULT_VIRTUAL_PATH, ) from log import get_logger -from models.api.responses.error import UnauthorizedResponse +from models.api.responses.error import ServiceUnavailableResponse, UnauthorizedResponse from models.config import JwkConfiguration logger = get_logger(__name__) @@ -141,6 +143,93 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key: return _internal +async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet: + """Load the configured JWK set and record bounded auth failures.""" + try: + return await get_jwk_set(str(config.url)) + except aiohttp.ClientError as exc: + logger.error("Failed to fetch JWK set: %s", exc) + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "jwk_fetch_error", start_time + ) + response = ServiceUnavailableResponse( + backend_name="JWK key server", + cause="Unable to reach authentication key server", + ) + raise HTTPException(**response.model_dump()) from exc + except json.JSONDecodeError as exc: + logger.error("Invalid JSON in JWK set response: %s", exc) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_json", start_time) + response = ServiceUnavailableResponse( + backend_name="JWK key server", + cause="Authentication key server returned invalid data", + ) + raise HTTPException(**response.model_dump()) from exc + except JoseError as exc: + logger.error("Invalid JWK set format: %s", exc) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_jwk", start_time) + response = ServiceUnavailableResponse( + backend_name="JWK key server", + cause="Authentication keys are malformed", + ) + raise HTTPException(**response.model_dump()) from exc + + +def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any: + """Decode a JWT and record bounded auth failures.""" + try: + return jwt.decode(user_token, key=key_resolver_func(jwk_set)) + except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: + logger.warning("Token decode error: %s", exc) + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time + ) + if isinstance(exc, KeyNotFoundError): + cause = "Token signed by unknown key" + elif isinstance(exc, BadSignatureError): + cause = "Invalid token signature" + elif isinstance(exc, DecodeError): + cause = "Token could not be decoded" + else: + cause = "Token format error" + response = UnauthorizedResponse(cause=cause) + raise HTTPException(**response.model_dump()) from exc + + +def _validate_jwk_claims(claims: Any, start_time: float) -> None: + """Validate decoded JWT claims and record bounded auth failures.""" + try: + claims.validate() + except ExpiredTokenError as exc: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "token_expired", start_time) + response = UnauthorizedResponse(cause="Token has expired") + raise HTTPException(**response.model_dump()) from exc + except JoseError as exc: + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "token_validation_error", start_time + ) + response = UnauthorizedResponse(cause="Token validation failed") + raise HTTPException(**response.model_dump()) from exc + + +def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str: + """Return a required JWT claim and record bounded auth failures when missing.""" + try: + value = claims[claim_name] + except KeyError as exc: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time) + response = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}") + raise HTTPException(**response.model_dump()) from exc + if not isinstance(value, str) or not value.strip(): + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time) + response = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}") + invalid_claim_error = ValueError( + f"Token claim {claim_name} must be a non-empty string" + ) + raise HTTPException(**response.model_dump()) from invalid_claim_error + return value + + class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods """JWK AuthDependency class for JWK-based JWT authentication.""" @@ -189,73 +278,40 @@ async def __call__(self, request: Request) -> AuthTuple: extracted from the validated JWT. Only returned on successful authentication; all error paths raise HTTPException. """ + start_time = time.monotonic() + if not request.headers.get("Authorization"): + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time + ) response = UnauthorizedResponse(cause="No Authorization header found") raise HTTPException(**response.model_dump()) - user_token = extract_user_token(request.headers) - - try: - jwk_set = await get_jwk_set(str(self.config.url)) - except aiohttp.ClientError as exc: - logger.error("Failed to fetch JWK set: %s", exc) - response = UnauthorizedResponse( - cause="Unable to reach authentication key server" - ) - raise HTTPException(**response.model_dump()) from exc - except json.JSONDecodeError as exc: - logger.error("Invalid JSON in JWK set response: %s", exc) - response = UnauthorizedResponse( - cause="Authentication key server returned invalid data" - ) - raise HTTPException(**response.model_dump()) from exc - except JoseError as exc: - logger.error("Invalid JWK set format: %s", exc) - response = UnauthorizedResponse(cause="Authentication keys are malformed") - raise HTTPException(**response.model_dump()) from exc - - try: - claims = jwt.decode(user_token, key=key_resolver_func(jwk_set)) - except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: - logger.warning("Token decode error: %s", exc) - cause_map = { - KeyNotFoundError: "Token signed by unknown key", - BadSignatureError: "Invalid token signature", - DecodeError: "Token could not be decoded", - JoseError: "Token format error", - } - response = UnauthorizedResponse( - cause=cause_map.get(type(exc), "Unknown token error") - ) - raise HTTPException(**response.model_dump()) from exc - try: - claims.validate() - except ExpiredTokenError as exc: - response = UnauthorizedResponse(cause="Token has expired") - raise HTTPException(**response.model_dump()) from exc - except JoseError as exc: - response = UnauthorizedResponse(cause="Token validation failed") - raise HTTPException(**response.model_dump()) from exc - - try: - user_id: str = claims[self.config.jwt_configuration.user_id_claim] - except KeyError as exc: - missing_claim = self.config.jwt_configuration.user_id_claim - response = UnauthorizedResponse( - cause=f"Token missing claim: {missing_claim}" + user_token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time ) - raise HTTPException(**response.model_dump()) from exc - - try: - username: str = claims[self.config.jwt_configuration.username_claim] - except KeyError as exc: - missing_claim = self.config.jwt_configuration.username_claim - response = UnauthorizedResponse( - cause=f"Token missing claim: {missing_claim}" + raise + except Exception: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error while extracting JWK bearer token") + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", start_time ) - raise HTTPException(**response.model_dump()) from exc + raise + + jwk_set = await _get_jwk_set_for_auth(self.config, start_time) + claims = _decode_jwk_claims(user_token, jwk_set, start_time) + _validate_jwk_claims(claims, start_time) + user_id = _get_required_claim( + claims, self.config.jwt_configuration.user_id_claim, start_time + ) + username = _get_required_claim( + claims, self.config.jwt_configuration.username_claim, start_time + ) logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time) return user_id, username, self.skip_userid_check, user_token diff --git a/src/authentication/k8s.py b/src/authentication/k8s.py index 4a8c9dbdf..0e7ea01e8 100644 --- a/src/authentication/k8s.py +++ b/src/authentication/k8s.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with K8S/OCP.""" import os +import time from http import HTTPStatus from typing import Optional, Self, cast @@ -10,9 +11,9 @@ from kubernetes.config import ConfigException from authentication.interface import NO_AUTH_TUPLE, AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from configuration import configuration -from constants import DEFAULT_VIRTUAL_PATH +from constants import AUTH_MOD_K8S, DEFAULT_VIRTUAL_PATH from log import get_logger from models.api.responses.error import ( ForbiddenResponse, @@ -382,7 +383,141 @@ def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReviewStatus] raise HTTPException(**response_obj.model_dump()) from e except Exception as e: # pylint: disable=broad-exception-caught logger.error("Unexpected error during TokenReview: %s", e) - return None + response_obj = InternalServerErrorResponse( + response="Internal server error", + cause=f"Unexpected error during TokenReview: {e}", + ) + raise HTTPException(**response_obj.model_dump()) from e + + +def _resolve_kube_admin_uid( + user: kubernetes.client.V1UserInfo, start_time: float +) -> str: + """Resolve the effective UID for a Kubernetes user. + + For ``kube:admin`` users, the UID is replaced with the cluster ID + to provide a stable identifier. For all other users, the original + UID is returned unchanged. + + Args: + user: Kubernetes user info from token review. + start_time: Monotonic timestamp for auth duration metrics. + + Returns: + The resolved user UID string. + + Raises: + HTTPException: If cluster ID resolution fails for kube:admin users. + """ + if user.username != "kube:admin": + return cast(str, user.uid) + try: + return K8sClientSingleton.get_cluster_id() + except K8sAPIConnectionError as e: + logger.error("Cannot connect to Kubernetes API: %s", e) + record_auth_metrics(AUTH_MOD_K8S, "failure", "k8s_api_unavailable", start_time) + resp = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=str(e), + ) + raise HTTPException(**resp.model_dump()) from e + except K8sConfigurationError as e: + logger.error("Cluster configuration error: %s", e) + record_auth_metrics(AUTH_MOD_K8S, "failure", "k8s_config_error", start_time) + resp = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**resp.model_dump()) from e + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error while resolving kube:admin cluster ID") + record_auth_metrics(AUTH_MOD_K8S, "failure", "unexpected_error", start_time) + resp = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**resp.model_dump()) from e + + +def _create_subject_access_review( + user: kubernetes.client.V1UserInfo, virtual_path: str, start_time: float +) -> kubernetes.client.V1SubjectAccessReview: + """Create a Kubernetes SubjectAccessReview and record bounded failures.""" + try: + authorization_api = K8sClientSingleton.get_authz_api() + except Exception as e: + logger.error("Failed to initialize Kubernetes authorization API: %s", e) + record_auth_metrics( + AUTH_MOD_K8S, "failure", "authorization_check_error", start_time + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=f"Unable to initialize Kubernetes client: {e}", + ) + raise HTTPException(**response.model_dump()) from e + + try: + sar = kubernetes.client.V1SubjectAccessReview( + spec=kubernetes.client.V1SubjectAccessReviewSpec( + user=user.username, + groups=user.groups, + non_resource_attributes=kubernetes.client.V1NonResourceAttributes( + path=virtual_path, verb="get" + ), + ) + ) + return cast( + kubernetes.client.V1SubjectAccessReview, + authorization_api.create_subject_access_review(sar), + ) + except ApiException as e: + record_auth_metrics( + AUTH_MOD_K8S, "failure", "authorization_check_error", start_time + ) + if e.status is None: + logger.error( + "Kubernetes API error during SubjectAccessReview with no status code: %s", + e.reason, + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=f"Failed to connect to Kubernetes API: {e.reason}", + ) + raise HTTPException(**response.model_dump()) from e + + if ( + e.status >= HTTPStatus.INTERNAL_SERVER_ERROR + or e.status == HTTPStatus.TOO_MANY_REQUESTS + ): + logger.error( + "Kubernetes API unavailable during SubjectAccessReview (status %s): %s", + e.status, + e.reason, + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=f"Kubernetes API unavailable: {e.reason} (status {e.status})", + ) + raise HTTPException(**response.model_dump()) from e + + logger.error( + "Kubernetes API returned client error during SubjectAccessReview (status %s): %s", + e.status, + e.reason, + ) + response_obj = InternalServerErrorResponse( + response="Internal server error", + cause=f"Kubernetes API request failed: {e.reason} (status {e.status})", + ) + raise HTTPException(**response_obj.model_dump()) from e + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error during SubjectAccessReview") + record_auth_metrics(AUTH_MOD_K8S, "failure", "unexpected_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e class K8SAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods @@ -436,80 +571,59 @@ async def __call__(self, request: Request) -> AuthTuple: ------ HTTPException: If authentication or authorization fails. """ + start_time = time.monotonic() + # LCORE-694: Config option to skip authorization for readiness and liveness probe if not request.headers.get("Authorization"): if configuration.authentication_configuration.skip_for_health_probes: if request.url.path in ("/readiness", "/liveness"): + record_auth_metrics( + AUTH_MOD_K8S, "skipped", "health_probe", start_time + ) return NO_AUTH_TUPLE # Skip auth for metrics endpoint when configured if configuration.authentication_configuration.skip_for_metrics: if request.url.path in ("/metrics",): + record_auth_metrics(AUTH_MOD_K8S, "skipped", "metrics", start_time) return NO_AUTH_TUPLE - token = extract_user_token(request.headers) - user_info = get_user_info(token) + try: + token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics(AUTH_MOD_K8S, "failure", "missing_token", start_time) + raise + try: + user_info = get_user_info(token) + except HTTPException: + record_auth_metrics( + AUTH_MOD_K8S, "failure", "token_review_error", start_time + ) + raise if user_info is None: + record_auth_metrics(AUTH_MOD_K8S, "failure", "invalid_token", start_time) response = UnauthorizedResponse(cause="Invalid or expired Kubernetes token") raise HTTPException(**response.model_dump()) # Cast user to proper type for type checking user = cast(kubernetes.client.V1UserInfo, user_info.user) - if user.username == "kube:admin": - try: - user.uid = K8sClientSingleton.get_cluster_id() - except K8sAPIConnectionError as e: - # Kubernetes API is unreachable - return 503 - logger.error("Cannot connect to Kubernetes API: %s", e) - response = ServiceUnavailableResponse( - backend_name="Kubernetes API", - cause=str(e), - ) - raise HTTPException(**response.model_dump()) from e - except K8sConfigurationError as e: - # Cluster misconfiguration or client error - return 500 - logger.error("Cluster configuration error: %s", e) - response = InternalServerErrorResponse( - response="Internal server error", - cause=str(e), - ) - raise HTTPException(**response.model_dump()) from e - - try: - authorization_api = K8sClientSingleton.get_authz_api() - sar = kubernetes.client.V1SubjectAccessReview( - spec=kubernetes.client.V1SubjectAccessReviewSpec( - user=user.username, - groups=user.groups, - non_resource_attributes=kubernetes.client.V1NonResourceAttributes( - path=self.virtual_path, verb="get" - ), - ) - ) - sar_response = cast( - kubernetes.client.V1SubjectAccessReview, - authorization_api.create_subject_access_review(sar), - ) - - except Exception as e: - logger.error("API exception during SubjectAccessReview: %s", e) - response = ServiceUnavailableResponse( - backend_name="Kubernetes API", - cause="Unable to perform authorization check", - ) - raise HTTPException(**response.model_dump()) from e + user_uid = _resolve_kube_admin_uid(user, start_time) + sar_response = _create_subject_access_review( + user, self.virtual_path, start_time + ) sar_status = cast( kubernetes.client.V1SubjectAccessReviewStatus, sar_response.status ) - user_uid = cast(str, user.uid) username = cast(str, user.username) if not sar_status.allowed: + record_auth_metrics(AUTH_MOD_K8S, "failure", "not_authorized", start_time) response = ForbiddenResponse.endpoint(user_id=user_uid) raise HTTPException(**response.model_dump()) + record_auth_metrics(AUTH_MOD_K8S, "success", "authenticated", start_time) return ( user_uid, username, diff --git a/src/authentication/noop.py b/src/authentication/noop.py index 6d32f45c3..bab22479a 100644 --- a/src/authentication/noop.py +++ b/src/authentication/noop.py @@ -1,9 +1,13 @@ """Manage authentication flow for FastAPI endpoints with no-op auth.""" +import time + from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple +from authentication.utils import record_auth_metrics from constants import ( + AUTH_MOD_NOOP, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -47,6 +51,8 @@ async def __call__(self, request: Request) -> AuthTuple: - skip_userid_check: True to indicate the user ID check is skipped. - token: NO_USER_TOKEN. """ + start_time = time.monotonic() + logger.warning( "No-op authentication dependency is being used. " "The service is running in insecure mode intended solely for development purposes" @@ -54,6 +60,8 @@ async def __call__(self, request: Request) -> AuthTuple: # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) if not user_id: + record_auth_metrics(AUTH_MOD_NOOP, "failure", "empty_user_id", start_time) raise HTTPException(status_code=400, detail="user_id cannot be empty") logger.debug("Retrieved user ID: %s", user_id) + record_auth_metrics(AUTH_MOD_NOOP, "skipped", "no_auth_required", start_time) return user_id, DEFAULT_USER_NAME, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/authentication/noop_with_token.py b/src/authentication/noop_with_token.py index 0656d952a..3b33404ac 100644 --- a/src/authentication/noop_with_token.py +++ b/src/authentication/noop_with_token.py @@ -6,14 +6,21 @@ - Reads a user token from request headers via `authentication.utils.extract_user_token`. - Reads `user_id` from query params (falls back to `DEFAULT_USER_UID`) and pairs it with `DEFAULT_USER_NAME`. -- Returns a tuple: (user_id, DEFAULT_USER_NAME, user_token). +- Returns a tuple: (user_id, DEFAULT_USER_NAME, skip_userid_check, user_token). """ +import time + from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import ( + AUTH_CAUSE_NO_HEADER, + extract_user_token, + record_auth_metrics, +) from constants import ( + AUTH_MOD_NOOP_WITH_TOKEN, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -59,15 +66,34 @@ async def __call__(self, request: Request) -> AuthTuple: - skip_userid_check: True to indicate user-id checks are skipped. - user_token: Token extracted from the request headers. """ + start_time = time.monotonic() + logger.warning( "No-op with token authentication dependency is being used. " "The service is running in insecure mode intended solely for development purposes" ) # try to extract user token from request - user_token = extract_user_token(request.headers) + try: + user_token = extract_user_token(request.headers) + except HTTPException as exc: + # Distinguish missing header from malformed token using structured detail + reason = "missing_token" + if ( + isinstance(exc.detail, dict) + and exc.detail.get("cause") == AUTH_CAUSE_NO_HEADER + ): + reason = "missing_header" + record_auth_metrics(AUTH_MOD_NOOP_WITH_TOKEN, "failure", reason, start_time) + raise # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) if not user_id: + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "failure", "empty_user_id", start_time + ) raise HTTPException(status_code=400, detail="user_id cannot be empty") logger.debug("Retrieved user ID: %s", user_id) + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "skipped", "no_auth_required", start_time + ) return user_id, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authentication/rh_identity.py b/src/authentication/rh_identity.py index f772ae9e5..026a45b8e 100644 --- a/src/authentication/rh_identity.py +++ b/src/authentication/rh_identity.py @@ -6,13 +6,16 @@ import base64 import json -from typing import Any, Optional +import time +from typing import Any, Final, Optional from fastapi import HTTPException, Request from authentication.interface import NO_AUTH_TUPLE, AuthInterface, AuthTuple +from authentication.utils import record_auth_metrics from configuration import configuration from constants import ( + AUTH_MOD_RH_IDENTITY, DEFAULT_RH_IDENTITY_MAX_HEADER_SIZE, DEFAULT_VIRTUAL_PATH, NO_USER_TOKEN, @@ -24,6 +27,9 @@ RH_INSIGHTS_REQUEST_ID_HEADER = "x-rh-insights-request-id" REQUEST_ID_HEADER = "x-request-id" +HEALTH_PROBE_SKIP_PATHS: Final[frozenset[str]] = frozenset({"/readiness", "/liveness"}) +METRICS_SKIP_PATH: Final[str] = "/metrics" + def _get_request_id(request: Request) -> str: """Return the inbound request identifier available during authentication.""" @@ -33,6 +39,30 @@ def _get_request_id(request: Request) -> str: ) +def _record_rh_identity_auth(result: str, reason: str, start_time: float) -> None: + """Record RH Identity authentication metrics with bounded labels.""" + record_auth_metrics(AUTH_MOD_RH_IDENTITY, result, reason, start_time) + + +def _normalized_request_path(request: Request) -> str: + """Return a canonical request path for exact skip-path comparisons.""" + return request.url.path.rstrip("/") or "/" + + +def _get_auth_skip_tuple(request: Request, start_time: float) -> Optional[AuthTuple]: + """Return an auth tuple for configured RH Identity skip paths.""" + request_path = _normalized_request_path(request) + if request_path in HEALTH_PROBE_SKIP_PATHS: + if configuration.authentication_configuration.skip_for_health_probes: + _record_rh_identity_auth("skipped", "health_probe", start_time) + return NO_AUTH_TUPLE + if request_path == METRICS_SKIP_PATH: + if configuration.authentication_configuration.skip_for_metrics: + _record_rh_identity_auth("skipped", "metrics", start_time) + return NO_AUTH_TUPLE + return None + + class RHIdentityData: """Extracts and validates Red Hat Identity header data. @@ -73,6 +103,12 @@ def _validate_structure(self) -> None: raise HTTPException(status_code=400, detail="Invalid identity data") identity = self.identity_data["identity"] + if not isinstance(identity, dict): + logger.warning( + "Identity validation failed: 'identity' is %s, expected dict", + type(identity).__name__, + ) + raise HTTPException(status_code=400, detail="Invalid identity data") if "type" not in identity: logger.warning("Identity validation failed: missing 'type' field") raise HTTPException(status_code=400, detail="Invalid identity data") @@ -106,6 +142,12 @@ def _validate_user_fields(self, identity: dict) -> None: ) raise HTTPException(status_code=400, detail="Invalid identity data") user = identity["user"] + if not isinstance(user, dict): + logger.warning( + "Identity validation failed: 'user' is %s, expected dict", + type(user).__name__, + ) + raise HTTPException(status_code=400, detail="Invalid identity data") if "user_id" not in user: logger.warning("Identity validation failed: missing 'user_id' in user data") raise HTTPException(status_code=400, detail="Invalid identity data") @@ -132,6 +174,12 @@ def _validate_system_fields(self, identity: dict) -> None: ) raise HTTPException(status_code=400, detail="Invalid identity data") system = identity["system"] + if not isinstance(system, dict): + logger.warning( + "Identity validation failed: 'system' is %s, expected dict", + type(system).__name__, + ) + raise HTTPException(status_code=400, detail="Invalid identity data") if "cn" not in system: logger.warning("Identity validation failed: missing 'cn' in system data") raise HTTPException(status_code=400, detail="Invalid identity data") @@ -321,18 +369,20 @@ async def __call__(self, request: Request) -> AuthTuple: - 400: Invalid base64, invalid JSON, or missing required fields - 403: Missing required entitlements """ + start_time = time.monotonic() + + # Short-circuit for configured skip paths (health probes, metrics) + # before any header validation so malformed headers on skip paths + # don't cause spurious errors. + auth_skip = _get_auth_skip_tuple(request, start_time) + if auth_skip is not None: + return auth_skip + # Extract header identity_header = request.headers.get("x-rh-identity") if not identity_header: - # Skip auth for health probes when configured - if request.url.path.endswith(("/readiness", "/liveness")): - if configuration.authentication_configuration.skip_for_health_probes: - return NO_AUTH_TUPLE - # Skip auth for metrics endpoint when configured - if request.url.path.endswith("/metrics"): - if configuration.authentication_configuration.skip_for_metrics: - return NO_AUTH_TUPLE logger.warning("Missing x-rh-identity header") + _record_rh_identity_auth("failure", "missing_header", start_time) raise HTTPException(status_code=401, detail="Missing x-rh-identity header") # Enforce header size limit before decoding @@ -342,6 +392,7 @@ async def __call__(self, request: Request) -> AuthTuple: len(identity_header), self.max_header_size, ) + _record_rh_identity_auth("failure", "header_too_large", start_time) raise HTTPException( status_code=400, detail="x-rh-identity header exceeds maximum allowed size", @@ -353,6 +404,7 @@ async def __call__(self, request: Request) -> AuthTuple: decoded_str = decoded_bytes.decode("utf-8") except (ValueError, UnicodeDecodeError) as exc: logger.warning("Invalid base64 in x-rh-identity header: %s", exc) + _record_rh_identity_auth("failure", "invalid_base64", start_time) raise HTTPException( status_code=400, detail="Invalid base64 encoding in x-rh-identity header", @@ -363,18 +415,38 @@ async def __call__(self, request: Request) -> AuthTuple: identity_data = json.loads(decoded_str) except json.JSONDecodeError as exc: logger.warning("Invalid JSON in x-rh-identity header: %s", exc) + _record_rh_identity_auth("failure", "invalid_json", start_time) raise HTTPException( status_code=400, detail="Invalid JSON in x-rh-identity header" ) from exc + # Guard against non-dict JSON payloads (null, list, number, etc.) + if not isinstance(identity_data, dict): + logger.warning( + "x-rh-identity decoded to non-dict type: %s", + type(identity_data).__name__, + ) + _record_rh_identity_auth("failure", "invalid_identity", start_time) + raise HTTPException( + status_code=400, + detail="Invalid identity data in x-rh-identity header", + ) + # Extract and validate identity - rh_identity = RHIdentityData( - identity_data, - required_entitlements=self.required_entitlements, - ) + try: + rh_identity = RHIdentityData( + identity_data, + required_entitlements=self.required_entitlements, + ) - # Validate entitlements if configured - rh_identity.validate_entitlements() + # Validate entitlements if configured + rh_identity.validate_entitlements() + except HTTPException as exc: + reason = ( + "entitlement_missing" if exc.status_code == 403 else "invalid_identity" + ) + _record_rh_identity_auth("failure", reason, start_time) + raise # Store identity data in request.state for downstream access request.state.rh_identity_data = rh_identity @@ -390,5 +462,6 @@ async def __call__(self, request: Request) -> AuthTuple: rh_identity.get_org_id(), rh_identity.get_system_id(), ) + _record_rh_identity_auth("success", "authenticated", start_time) return user_id, username, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/authentication/utils.py b/src/authentication/utils.py index aad460c00..988cb0a56 100644 --- a/src/authentication/utils.py +++ b/src/authentication/utils.py @@ -1,10 +1,22 @@ """Authentication utility functions.""" +import time +from typing import Final + from fastapi import HTTPException from starlette.datastructures import Headers +from log import get_logger +from metrics import recording from models.api.responses.error import UnauthorizedResponse +logger = get_logger(__name__) + +# Cause strings used in UnauthorizedResponse for token extraction failures. +# Auth modules compare against these constants to classify failure reasons. +AUTH_CAUSE_NO_HEADER: Final[str] = "No Authorization header found" +AUTH_CAUSE_NO_TOKEN: Final[str] = "No token found in Authorization header" + def extract_user_token(headers: Headers) -> str: """Extract the bearer token from an HTTP Authorization header. @@ -24,12 +36,42 @@ def extract_user_token(headers: Headers) -> str: """ authorization_header = headers.get("Authorization") if not authorization_header: - response = UnauthorizedResponse(cause="No Authorization header found") + response = UnauthorizedResponse(cause=AUTH_CAUSE_NO_HEADER) raise HTTPException(**response.model_dump()) scheme_and_token = authorization_header.strip().split() if len(scheme_and_token) != 2 or scheme_and_token[0].lower() != "bearer": - response = UnauthorizedResponse(cause="No token found in Authorization header") + response = UnauthorizedResponse(cause=AUTH_CAUSE_NO_TOKEN) raise HTTPException(**response.model_dump()) return scheme_and_token[1] + + +def record_auth_metrics( + auth_module: str, result: str, reason: str, start_time: float +) -> None: + """Record authentication attempt and duration metrics together. + + Parameters: + ---------- + auth_module: Configured authentication module name. + result: Bounded result label, such as ``success`` or ``failure``. + reason: Bounded reason label for the result. + start_time: Monotonic clock time captured at the start of auth handling. + + Returns: + ------- + None: Metrics are recorded as a side effect. + """ + try: + recording.record_auth_attempt(auth_module, result, reason) + recording.record_auth_duration( + auth_module, result, time.monotonic() - start_time + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to record authentication metrics for module %s with result %s", + auth_module, + result, + exc_info=True, + ) diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 63d7b8e29..d288bfe92 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -22,6 +22,20 @@ float("inf"), ) +AUTH_DURATION_BUCKETS: Final[tuple[float, ...]] = ( + 0.001, + 0.005, + 0.01, + 0.025, + 0.05, + 0.1, + 0.25, + 0.5, + 1.0, + 2.5, + 5.0, + float("inf"), +) # Counter to track REST API calls # This will be used to count how many times each API endpoint is called # and the status code of the response @@ -75,10 +89,27 @@ ["provider", "model", "endpoint"], ) + # Histogram to measure the latency of direct LLM inference backend calls. -llm_inference_duration_seconds = Histogram( +llm_inference_duration_seconds: Final[Histogram] = Histogram( "ls_llm_inference_duration_seconds", "LLM inference call duration", ["provider", "model", "endpoint", "result"], buckets=LLM_INFERENCE_DURATION_BUCKETS, ) + + +# Counter to track authentication attempts with bounded auth_module, result, and reason labels. +auth_attempts_total: Final[Counter] = Counter( + "ls_auth_attempts_total", + "Authentication attempts", + ["auth_module", "result", "reason"], +) + +# Histogram to measure authentication dependency latency with bounded module and result labels. +auth_duration_seconds: Final[Histogram] = Histogram( + "ls_auth_duration_seconds", + "Authentication duration", + ["auth_module", "result"], + buckets=AUTH_DURATION_BUCKETS, +) diff --git a/src/metrics/recording.py b/src/metrics/recording.py index a9b35d208..50af0b332 100644 --- a/src/metrics/recording.py +++ b/src/metrics/recording.py @@ -10,10 +10,84 @@ from typing import Final import metrics +from constants import SUPPORTED_AUTHENTICATION_MODULES from log import get_logger logger = get_logger(__name__) +AUTH_RESULT_SUCCESS: Final[str] = "success" +AUTH_RESULT_FAILURE: Final[str] = "failure" +AUTH_RESULT_SKIPPED: Final[str] = "skipped" +AUTH_RESULT_UNKNOWN: Final[str] = "unknown" +AUTH_REASON_UNKNOWN: Final[str] = "unknown" + +ALLOWED_AUTH_RESULTS: Final[frozenset[str]] = frozenset( + { + AUTH_RESULT_SUCCESS, + AUTH_RESULT_FAILURE, + AUTH_RESULT_SKIPPED, + } +) +# Allowed ``reason`` label values for auth metrics. This set must be kept in +# sync with all ``reason`` strings passed to ``record_auth_attempt`` across every +# authentication module (api_key_token, jwk_token, k8s, noop_with_token, +# rh_identity). Adding a new reason? Add it here too, otherwise it will be +# normalised to "unknown". +ALLOWED_AUTH_REASONS: Final[frozenset[str]] = frozenset( + { + "authenticated", + "authorization_check_error", + "empty_user_id", + "entitlement_missing", + "header_too_large", + "health_probe", + "invalid_base64", + "invalid_claim", + "invalid_identity", + "invalid_json", + "invalid_jwk", + "invalid_key", + "invalid_token", + "jwk_fetch_error", + "k8s_api_unavailable", + "k8s_config_error", + "malformed_token", + "metrics", + "missing_claim", + "missing_header", + "missing_token", + "no_auth_required", + "not_authorized", + "token_decode_error", + "token_expired", + "token_review_error", + "token_validation_error", + "unexpected_error", + "valid_key", + } +) + + +def normalize_auth_module(auth_module: str) -> str: + """Return a bounded authentication module label.""" + if auth_module in SUPPORTED_AUTHENTICATION_MODULES: + return auth_module + return AUTH_RESULT_UNKNOWN + + +def normalize_auth_result(result: str) -> str: + """Return a bounded authentication result label.""" + if result in ALLOWED_AUTH_RESULTS: + return result + return AUTH_RESULT_FAILURE + + +def normalize_auth_reason(reason: str) -> str: + """Return a bounded authentication reason label.""" + if reason in ALLOWED_AUTH_REASONS: + return reason + return AUTH_REASON_UNKNOWN + @contextmanager def measure_response_duration(path: str) -> Iterator[None]: @@ -157,3 +231,43 @@ def record_llm_inference_duration( ).observe(duration) except (AttributeError, TypeError, ValueError): logger.warning("Failed to update LLM inference duration metric", exc_info=True) + + +def record_auth_attempt(auth_module: str, result: str, reason: str) -> None: + """Record one authentication attempt. + + Args: + auth_module: Configured authentication module name. Unknown values are + recorded as ``unknown`` to keep metric cardinality bounded. + result: Bounded result label, such as ``success`` or ``failure``. + Unknown values are recorded as ``failure``. + reason: Bounded reason label for the result. Unknown values are recorded + as ``unknown``. + """ + try: + metrics.auth_attempts_total.labels( + normalize_auth_module(auth_module), + normalize_auth_result(result), + normalize_auth_reason(reason), + ).inc() + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authentication metric", exc_info=True) + + +def record_auth_duration(auth_module: str, result: str, duration: float) -> None: + """Record authentication duration. + + Args: + auth_module: Configured authentication module name. Unknown values are + recorded as ``unknown`` to keep metric cardinality bounded. + result: Bounded result label, such as ``success`` or ``failure``. + Unknown values are recorded as ``failure``. + duration: Authentication duration in seconds. + """ + try: + metrics.auth_duration_seconds.labels( + normalize_auth_module(auth_module), + normalize_auth_result(result), + ).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authentication duration metric", exc_info=True) diff --git a/tests/unit/authentication/test_jwk_token.py b/tests/unit/authentication/test_jwk_token.py index 7b3ae4d2f..e4ae16e2d 100644 --- a/tests/unit/authentication/test_jwk_token.py +++ b/tests/unit/authentication/test_jwk_token.py @@ -13,6 +13,7 @@ from pytest_mock import MockerFixture from authentication.jwk_token import JwkTokenAuthDependency, _jwk_cache +from constants import AUTH_MOD_JWK_TOKEN from models.config import JwkConfiguration, JwtConfiguration TEST_USER_ID = "test-user-123" @@ -503,6 +504,29 @@ async def test_no_bearer( assert detail["cause"] == "No token found in Authorization header" +@pytest.mark.asyncio +async def test_unexpected_token_extraction_error_records_metrics( + mocker: MockerFixture, + default_jwk_configuration: JwkConfiguration, + valid_token: str, +) -> None: + """Test unexpected token extraction errors are recorded before re-raising.""" + mocker.patch( + "authentication.jwk_token.extract_user_token", + side_effect=RuntimeError("header parser failed"), + ) + mock_metrics = mocker.patch("authentication.jwk_token.record_auth_metrics") + mocker.patch("authentication.jwk_token.time.monotonic", return_value=7.0) + dependency = JwkTokenAuthDependency(default_jwk_configuration) + + with pytest.raises(RuntimeError): + await dependency(dummy_request(valid_token)) + + mock_metrics.assert_called_once_with( + AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", 7.0 + ) + + @pytest.fixture def no_user_id_token( single_key_set: list[dict[str, Any]], @@ -532,6 +556,21 @@ def no_user_id_token( ).decode() +@pytest.fixture +def invalid_user_id_token( + single_key_set: list[dict[str, Any]], + token_payload: dict[str, Any], + token_header: dict[str, Any], +) -> str: + """Token with an invalid user_id claim value.""" + jwt_instance = JsonWebToken(algorithms=["RS256"]) + token_payload["user_id"] = "" + + return jwt_instance.encode( + token_header, token_payload, single_key_set[0]["private_key"] + ).decode() + + @pytest.mark.asyncio async def test_no_user_id( default_jwk_configuration: JwkConfiguration, @@ -552,6 +591,24 @@ async def test_no_user_id( ) +@pytest.mark.asyncio +async def test_invalid_user_id_claim_has_exception_chain( + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + invalid_user_id_token: str, +) -> None: + """Test invalid required claims preserve root-cause exception context.""" + _ = mocked_signing_keys_server + + dependency = JwkTokenAuthDependency(default_jwk_configuration) + + with pytest.raises(HTTPException) as exc_info: + await dependency(dummy_request(invalid_user_id_token)) + + assert exc_info.value.status_code == 401 + assert isinstance(exc_info.value.__cause__, ValueError) + + @pytest.fixture def no_username_token( single_key_set: list[dict[str, Any]], diff --git a/tests/unit/authentication/test_k8s.py b/tests/unit/authentication/test_k8s.py index 94e0eb633..2a8563796 100644 --- a/tests/unit/authentication/test_k8s.py +++ b/tests/unit/authentication/test_k8s.py @@ -8,7 +8,7 @@ import pytest from fastapi import HTTPException, Request -from kubernetes.client import AuthenticationV1Api, AuthorizationV1Api +from kubernetes.client import AuthenticationV1Api, AuthorizationV1Api, V1UserInfo from kubernetes.client.rest import ApiException from pytest_mock import MockerFixture @@ -21,6 +21,7 @@ K8SAuthDependency, K8sClientSingleton, K8sConfigurationError, + _create_subject_access_review, get_user_info, ) from configuration import AppConfig @@ -1135,3 +1136,90 @@ def test_get_user_info_api_error_handling( detail = cast(dict[str, str], exc_info.value.detail) assert detail["response"] == expected_response assert expected_cause_fragment in detail["cause"] + + +@pytest.mark.parametrize( + "api_status,reason,expected_status,expected_response,expected_cause_fragment", + [ + ( + HTTPStatus.SERVICE_UNAVAILABLE, + "Service Unavailable", + 503, + "Unable to connect to Kubernetes API", + "Service Unavailable", + ), + ( + HTTPStatus.TOO_MANY_REQUESTS, + "Too Many Requests", + 503, + "Unable to connect to Kubernetes API", + "Too Many Requests", + ), + ( + None, + "Connection failed", + 503, + "Unable to connect to Kubernetes API", + "Connection failed", + ), + ( + HTTPStatus.BAD_REQUEST, + "Bad Request", + 500, + "Internal server error", + "Bad Request", + ), + ], +) +def test_create_subject_access_review_api_error_handling( + mocker: MockerFixture, + api_status: Optional[int], + reason: str, + expected_status: int, + expected_response: str, + expected_cause_fragment: str, +) -> None: + """Test SubjectAccessReview maps Kubernetes API errors consistently.""" + mock_authz_api = mocker.patch("authentication.k8s.K8sClientSingleton.get_authz_api") + mock_authz_api.return_value.create_subject_access_review.side_effect = ApiException( + status=api_status, reason=reason + ) + mock_metrics = mocker.patch("authentication.k8s.record_auth_metrics") + user = cast( + V1UserInfo, MockK8sUser(username="user@example.com", groups=["lsc-group"]) + ) + + with pytest.raises(HTTPException) as exc_info: + _create_subject_access_review(user, "/api/lightspeed/v1/query", 10.0) + + assert exc_info.value.status_code == expected_status + detail = cast(dict[str, str], exc_info.value.detail) + assert detail["response"] == expected_response + assert expected_cause_fragment in detail["cause"] + mock_metrics.assert_called_once_with( + "k8s", "failure", "authorization_check_error", 10.0 + ) + + +def test_create_subject_access_review_authz_init_failure( + mocker: MockerFixture, +) -> None: + """Test get_authz_api() init failure records authorization check errors.""" + mocker.patch( + "authentication.k8s.K8sClientSingleton.get_authz_api", + side_effect=RuntimeError("failed to load kubeconfig"), + ) + mock_metrics = mocker.patch("authentication.k8s.record_auth_metrics") + user = cast( + V1UserInfo, MockK8sUser(username="user@example.com", groups=["lsc-group"]) + ) + + with pytest.raises(HTTPException) as exc_info: + _create_subject_access_review(user, "/api/lightspeed/v1/query", 10.0) + + assert exc_info.value.status_code == HTTPStatus.SERVICE_UNAVAILABLE + detail = cast(dict[str, str], exc_info.value.detail) + assert "Unable to initialize Kubernetes client" in detail["cause"] + mock_metrics.assert_called_once_with( + "k8s", "failure", "authorization_check_error", 10.0 + ) diff --git a/tests/unit/authentication/test_rh_identity.py b/tests/unit/authentication/test_rh_identity.py index 814678d75..aa08f85a2 100644 --- a/tests/unit/authentication/test_rh_identity.py +++ b/tests/unit/authentication/test_rh_identity.py @@ -254,6 +254,14 @@ def test_validate_entitlements( {"identity": {"type": "User", "org_id": "123"}}, "Invalid identity data", ), + ( + {"identity": {"type": "User", "org_id": "123", "user": 1}}, + "Invalid identity data", + ), + ( + {"identity": {"type": "User", "org_id": "123", "user": []}}, + "Invalid identity data", + ), ( { "identity": { @@ -278,6 +286,10 @@ def test_validate_entitlements( {"identity": {"type": "System", "org_id": "123"}}, "Invalid identity data", ), + ( + {"identity": {"type": "System", "org_id": "123", "system": 1}}, + "Invalid identity data", + ), ( {"identity": {"type": "System", "org_id": "123", "system": {}}}, "Invalid identity data", @@ -424,6 +436,49 @@ async def test_invalid_json(self, mocker: MockerFixture) -> None: assert exc_info.value.status_code == 400 assert "Invalid JSON" in str(exc_info.value.detail) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "identity_data", + [ + pytest.param({"identity": 1}, id="identity-not-dict"), + pytest.param( + {"identity": {"type": "User", "user": 1}}, + id="user-not-dict", + ), + pytest.param( + {"identity": {"type": "User", "user": []}}, + id="user-list", + ), + pytest.param( + {"identity": {"type": "System", "system": 1}}, + id="system-not-dict", + ), + ], + ) + async def test_invalid_nested_identity_payloads_record_metrics( + self, + mocker: MockerFixture, + identity_data: dict, + ) -> None: + """Test malformed nested identity payloads record invalid identity metrics.""" + auth_dep = RHIdentityAuthDependency() + header_value = create_auth_header(identity_data) + request = create_request_with_header(mocker, header_value) + mock_record = mocker.patch( + "authentication.rh_identity._record_rh_identity_auth" + ) + + with pytest.raises(HTTPException) as exc_info: + await auth_dep(request) + + assert exc_info.value.status_code == 400 + assert "Invalid identity data" in str(exc_info.value.detail) + mock_record.assert_called_once() + result, reason, duration = mock_record.call_args.args + assert result == "failure" + assert reason == "invalid_identity" + assert duration >= 0 + @pytest.mark.asyncio @pytest.mark.parametrize( "required_entitlements,should_raise,expected_error", @@ -509,9 +564,9 @@ def _mock_configuration( "path", [ "/readiness", + "/readiness/", "/liveness", - "/api/lightspeed/readiness", - "/api/lightspeed/liveness", + "/liveness/", ], ) async def test_probe_paths_skip_auth_when_enabled( @@ -532,8 +587,6 @@ async def test_probe_paths_skip_auth_when_enabled( [ "/readiness", "/liveness", - "/api/lightspeed/readiness", - "/api/lightspeed/liveness", ], ) async def test_probe_paths_require_auth_when_disabled( @@ -550,11 +603,19 @@ async def test_probe_paths_require_auth_when_disabled( assert exc_info.value.status_code == 401 @pytest.mark.asyncio - @pytest.mark.parametrize("path", ["/", "/v1/query"]) + @pytest.mark.parametrize( + "path", + [ + "/", + "/v1/query", + "/api/lightspeed/readiness", + "/api/lightspeed/liveness", + ], + ) async def test_non_probe_paths_require_auth_when_skip_enabled( self, mocker: MockerFixture, path: str ) -> None: - """Test non-probe paths still require auth even when skip_for_health_probes is True.""" + """Test non-probe paths still require auth even when probe skipping is enabled.""" self._mock_configuration(mocker, skip_for_health_probes=True) auth_dep = RHIdentityAuthDependency() @@ -587,7 +648,7 @@ def _mock_configuration(mocker: MockerFixture, skip_for_metrics: bool) -> None: "path", [ "/metrics", - "/api/lightspeed/metrics", + "/metrics/", ], ) async def test_metrics_path_skips_auth_when_enabled( @@ -607,7 +668,6 @@ async def test_metrics_path_skips_auth_when_enabled( "path", [ "/metrics", - "/api/lightspeed/metrics", ], ) async def test_metrics_path_requires_auth_when_disabled( @@ -624,11 +684,19 @@ async def test_metrics_path_requires_auth_when_disabled( assert exc_info.value.status_code == 401 @pytest.mark.asyncio - @pytest.mark.parametrize("path", ["/", "/v1/query"]) + @pytest.mark.parametrize( + "path", + [ + "/", + "/v1/query", + "/api/lightspeed/metrics", + "/v1/notmetrics", + ], + ) async def test_non_metrics_paths_require_auth_when_skip_enabled( self, mocker: MockerFixture, path: str ) -> None: - """Test non-metrics paths still require auth even when skip_for_metrics is True.""" + """Test non-metrics paths still require auth even when metrics skipping is enabled.""" self._mock_configuration(mocker, skip_for_metrics=True) auth_dep = RHIdentityAuthDependency() diff --git a/tests/unit/authentication/test_utils.py b/tests/unit/authentication/test_utils.py index 36490d2bb..a3c4c43d9 100644 --- a/tests/unit/authentication/test_utils.py +++ b/tests/unit/authentication/test_utils.py @@ -3,9 +3,10 @@ from typing import cast from fastapi import HTTPException +from pytest_mock import MockerFixture from starlette.datastructures import Headers -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics def test_extract_user_token() -> None: @@ -41,3 +42,50 @@ def test_extract_user_token_invalid_format() -> None: "Missing or invalid credentials provided by client" ) assert detail["cause"] == "No token found in Authorization header" + + +def test_record_auth_metrics_records_attempt_and_duration( + mocker: MockerFixture, +) -> None: + """Test recording auth attempt and duration through the shared helper.""" + mock_attempt = mocker.patch("authentication.utils.recording.record_auth_attempt") + mock_duration = mocker.patch("authentication.utils.recording.record_auth_duration") + mock_monotonic = mocker.patch("authentication.utils.time.monotonic") + mock_monotonic.return_value = 12.5 + + record_auth_metrics("jwk-token", "success", "authenticated", 10.0) + + mock_attempt.assert_called_once_with("jwk-token", "success", "authenticated") + mock_duration.assert_called_once_with("jwk-token", "success", 2.5) + + +def test_record_auth_metrics_records_failure_attempt_and_duration( + mocker: MockerFixture, +) -> None: + """Test recording failed auth attempts and duration through the shared helper.""" + mock_attempt = mocker.patch("authentication.utils.recording.record_auth_attempt") + mock_duration = mocker.patch("authentication.utils.recording.record_auth_duration") + mock_monotonic = mocker.patch("authentication.utils.time.monotonic") + mock_monotonic.return_value = 15.25 + + record_auth_metrics("jwk-token", "failure", "missing_token", 10.0) + + mock_attempt.assert_called_once_with("jwk-token", "failure", "missing_token") + mock_duration.assert_called_once_with("jwk-token", "failure", 5.25) + + +def test_record_auth_metrics_swallows_recording_exceptions( + mocker: MockerFixture, +) -> None: + """Test that record_auth_metrics catches and logs recording failures.""" + mocker.patch( + "authentication.utils.recording.record_auth_attempt", + side_effect=RuntimeError("metrics backend unavailable"), + ) + mock_warning = mocker.patch("authentication.utils.logger.warning") + + # Should not raise despite recording failure + record_auth_metrics("jwk-token", "failure", "missing_token", 10.0) + + mock_warning.assert_called_once() + assert "Failed to record authentication metrics" in mock_warning.call_args[0][0] diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py index 6cf4fcb6e..6c797142f 100644 --- a/tests/unit/metrics/test_recording.py +++ b/tests/unit/metrics/test_recording.py @@ -218,6 +218,7 @@ def test_histogram_recorders_observe_metrics_and_log_errors( ) + def test_record_llm_inference_duration_bounds_result_label( mocker: MockerFixture, ) -> None: @@ -234,3 +235,81 @@ def test_record_llm_inference_duration_bounds_result_label( "vertexai", "gemini", "/v1/responses", "failure" ) mock_metric.labels.return_value.observe.assert_called_once_with(2.0) + + +def test_record_auth_attempt_updates_metric( + mocker: MockerFixture, +) -> None: + """Test auth attempt helper success behavior.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_attempts_total") + + recording.record_auth_attempt("rh-identity", "success", "authenticated") + + mock_metric.labels.assert_called_once_with( + "rh-identity", "success", "authenticated" + ) + mock_metric.labels.return_value.inc.assert_called_once() + + +def test_record_auth_attempt_bounds_labels(mocker: MockerFixture) -> None: + """Test auth attempt helper normalizes unbounded label values.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_attempts_total") + + recording.record_auth_attempt("customer-123", "timeout", "database-down") + + mock_metric.labels.assert_called_once_with("unknown", "failure", "unknown") + mock_metric.labels.return_value.inc.assert_called_once() + + +def test_record_auth_attempt_logs_metric_errors( + mocker: MockerFixture, + recording_logger: MockType, +) -> None: + """Test auth attempt helper logs and swallows metric failures.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_attempts_total") + + mock_metric.labels.return_value.inc.side_effect = AttributeError("missing") + + recording.record_auth_attempt("rh-identity", "success", "authenticated") + + recording_logger.warning.assert_called_once_with( + "Failed to update authentication metric", exc_info=True + ) + + +def test_record_auth_duration_updates_metric( + mocker: MockerFixture, +) -> None: + """Test auth duration helper success behavior.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_duration_seconds") + + recording.record_auth_duration("rh-identity", "success", 0.5) + + mock_metric.labels.assert_called_once_with("rh-identity", "success") + mock_metric.labels.return_value.observe.assert_called_once_with(0.5) + + +def test_record_auth_duration_bounds_labels(mocker: MockerFixture) -> None: + """Test auth duration helper normalizes unbounded label values.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_duration_seconds") + + recording.record_auth_duration("customer-123", "timeout", 0.5) + + mock_metric.labels.assert_called_once_with("unknown", "failure") + mock_metric.labels.return_value.observe.assert_called_once_with(0.5) + + +def test_record_auth_duration_logs_metric_errors( + mocker: MockerFixture, + recording_logger: MockType, +) -> None: + """Test auth duration helper logs and swallows metric failures.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_duration_seconds") + + mock_metric.labels.return_value.observe.side_effect = TypeError("bad") + + recording.record_auth_duration("rh-identity", "success", 0.5) + + recording_logger.warning.assert_called_once_with( + "Failed to update authentication duration metric", exc_info=True + ) From 1ed841147f1596049cb46ffef7077a6290f42797 Mon Sep 17 00:00:00 2001 From: Major Hayden Date: Wed, 29 Apr 2026 18:22:59 -0500 Subject: [PATCH 2/3] feat: add authorization monitoring metrics Signed-off-by: Major Hayden --- src/authorization/middleware.py | 97 ++++++++++++++------- src/metrics/__init__.py | 31 +++++++ src/metrics/recording.py | 83 ++++++++++++++++++ tests/unit/authorization/test_middleware.py | 31 +++++++ tests/unit/metrics/test_recording.py | 79 +++++++++++++++++ 5 files changed, 290 insertions(+), 31 deletions(-) diff --git a/src/authorization/middleware.py b/src/authorization/middleware.py index 2aaa8d415..5073e841d 100644 --- a/src/authorization/middleware.py +++ b/src/authorization/middleware.py @@ -1,5 +1,6 @@ """Authorization middleware and decorators.""" +import time from collections.abc import Callable from functools import lru_cache, wraps from typing import Any, Optional @@ -18,6 +19,7 @@ ) from configuration import configuration from log import get_logger +from metrics import recording from models.api.responses.error import ( ForbiddenResponse, InternalServerErrorResponse, @@ -27,6 +29,31 @@ logger = get_logger(__name__) +def _record_authorization_metrics( + action: Action, + result: str, + start_time: float, +) -> None: + """Record authorization metrics without affecting request authorization flow. + + Args: + action: Protected action being authorized. + result: Authorization result label. + start_time: Monotonic timestamp captured before authorization began. + """ + duration = time.monotonic() - start_time + + try: + recording.record_authorization_check(action.value, result) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to record authorization check metric", exc_info=True) + + try: + recording.record_authorization_duration(action.value, result, duration) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Failed to record authorization duration metric", exc_info=True) + + @lru_cache(maxsize=1) def get_authorization_resolvers() -> tuple[RolesResolver, AccessResolver]: """Get authorization resolvers from configuration (cached). @@ -124,39 +151,47 @@ async def _perform_authorization_check( HTTPException: with 403 Forbidden if the resolved roles are not permitted to perform `action`. """ - role_resolver, access_resolver = get_authorization_resolvers() + start_time = time.monotonic() + result = "error" try: - auth = kwargs["auth"] - except KeyError as exc: - logger.error( - "Authorization only allowed on endpoints that accept " - "'auth: Any = Depends(get_auth_dependency())'" - ) - response = InternalServerErrorResponse.generic() - raise HTTPException(**response.model_dump()) from exc - - # Everyone gets the everyone (aka *) role - everyone_roles = {"*"} - - user_roles = await role_resolver.resolve_roles(auth) | everyone_roles - - if not access_resolver.check_access(action, user_roles): - response = ForbiddenResponse.endpoint(user_id=auth[0]) - raise HTTPException(**response.model_dump()) - - authorized_actions = access_resolver.get_actions(user_roles) - - req: Optional[Request] = None - if "request" in kwargs and isinstance(kwargs["request"], Request): - req = kwargs["request"] - else: - for arg in args: - if isinstance(arg, Request): - req = arg - break - if req is not None: - req.state.authorized_actions = authorized_actions + role_resolver, access_resolver = get_authorization_resolvers() + + try: + auth = kwargs["auth"] + except KeyError as exc: + logger.error( + "Authorization only allowed on endpoints that accept " + "'auth: Any = Depends(get_auth_dependency())'" + ) + response = InternalServerErrorResponse.generic() + raise HTTPException(**response.model_dump()) from exc + + # Everyone gets the everyone (aka *) role + everyone_roles = {"*"} + + user_roles = await role_resolver.resolve_roles(auth) | everyone_roles + + if not access_resolver.check_access(action, user_roles): + response = ForbiddenResponse.endpoint(user_id=auth[0]) + result = "denied" + raise HTTPException(**response.model_dump()) + + authorized_actions = access_resolver.get_actions(user_roles) + + req: Optional[Request] = None + if "request" in kwargs and isinstance(kwargs["request"], Request): + req = kwargs["request"] + else: + for arg in args: + if isinstance(arg, Request): + req = arg + break + if req is not None: + req.state.authorized_actions = authorized_actions + result = "success" + finally: + _record_authorization_metrics(action, result, start_time) def authorize(action: Action) -> Callable: diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index d288bfe92..5324f35e7 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -36,6 +36,21 @@ 5.0, float("inf"), ) + +AUTHORIZATION_DURATION_BUCKETS: Final[tuple[float, ...]] = ( + 0.001, + 0.005, + 0.01, + 0.025, + 0.05, + 0.1, + 0.25, + 0.5, + 1.0, + 2.5, + 5.0, + float("inf"), +) # Counter to track REST API calls # This will be used to count how many times each API endpoint is called # and the status code of the response @@ -113,3 +128,19 @@ ["auth_module", "result"], buckets=AUTH_DURATION_BUCKETS, ) + +# Counter to track authorization checks by bounded protected action and result. +# Actions are normalized against the Action enum; results are success, denied, or error. +authorization_checks_total: Final[Counter] = Counter( + "ls_authorization_checks_total", + "Authorization checks", + ["action", "result"], +) + +# Histogram to measure authorization check latency by bounded action and result. +authorization_duration_seconds: Final[Histogram] = Histogram( + "ls_authorization_duration_seconds", + "Authorization check duration", + ["action", "result"], + buckets=AUTHORIZATION_DURATION_BUCKETS, +) diff --git a/src/metrics/recording.py b/src/metrics/recording.py index 50af0b332..05d9a8a92 100644 --- a/src/metrics/recording.py +++ b/src/metrics/recording.py @@ -12,6 +12,7 @@ import metrics from constants import SUPPORTED_AUTHENTICATION_MODULES from log import get_logger +from models.config import Action logger = get_logger(__name__) @@ -89,6 +90,50 @@ def normalize_auth_reason(reason: str) -> str: return AUTH_REASON_UNKNOWN +AUTHORIZATION_ACTION_UNKNOWN: Final[str] = "unknown" +AUTHORIZATION_RESULT_SUCCESS: Final[str] = "success" +AUTHORIZATION_RESULT_DENIED: Final[str] = "denied" +AUTHORIZATION_RESULT_ERROR: Final[str] = "error" + +ALLOWED_AUTHORIZATION_ACTIONS: Final[frozenset[str]] = frozenset( + action.value for action in Action +) +ALLOWED_AUTHORIZATION_RESULTS: Final[frozenset[str]] = frozenset( + { + AUTHORIZATION_RESULT_SUCCESS, + AUTHORIZATION_RESULT_DENIED, + AUTHORIZATION_RESULT_ERROR, + } +) + + +def normalize_authorization_action(action: str) -> str: + """Normalize authorization action labels to the bounded Action enum values. + + Args: + action: Raw authorization action label. + + Returns: + The action when it is a known protected action, otherwise ``unknown``. + """ + if action in ALLOWED_AUTHORIZATION_ACTIONS: + return action + return AUTHORIZATION_ACTION_UNKNOWN + + +def normalize_authorization_result(result: str) -> str: + """Normalize authorization result labels to the bounded result set. + + Args: + result: Raw authorization result label. + + Returns: + The result when it is allowed, otherwise ``error``. + """ + if result in ALLOWED_AUTHORIZATION_RESULTS: + return result + return AUTHORIZATION_RESULT_ERROR + @contextmanager def measure_response_duration(path: str) -> Iterator[None]: """Measure REST API response duration for a route path. @@ -233,6 +278,7 @@ def record_llm_inference_duration( logger.warning("Failed to update LLM inference duration metric", exc_info=True) + def record_auth_attempt(auth_module: str, result: str, reason: str) -> None: """Record one authentication attempt. @@ -271,3 +317,40 @@ def record_auth_duration(auth_module: str, result: str, duration: float) -> None ).observe(duration) except (AttributeError, TypeError, ValueError): logger.warning("Failed to update authentication duration metric", exc_info=True) + + +def record_authorization_check(action: str, result: str) -> None: + """Record one authorization check. + + Args: + action: Protected action name. Unknown values are recorded as ``unknown``. + result: Bounded result label. Unknown values are recorded as ``error``. + """ + normalized_action = normalize_authorization_action(action) + normalized_result = normalize_authorization_result(result) + + try: + metrics.authorization_checks_total.labels( + normalized_action, normalized_result + ).inc() + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authorization metric", exc_info=True) + + +def record_authorization_duration(action: str, result: str, duration: float) -> None: + """Record authorization check duration. + + Args: + action: Protected action name. Unknown values are recorded as ``unknown``. + result: Bounded result label. Unknown values are recorded as ``error``. + duration: Authorization check duration in seconds. + """ + normalized_action = normalize_authorization_action(action) + normalized_result = normalize_authorization_result(result) + + try: + metrics.authorization_duration_seconds.labels( + normalized_action, normalized_result + ).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authorization duration metric", exc_info=True) diff --git a/tests/unit/authorization/test_middleware.py b/tests/unit/authorization/test_middleware.py index 9f2904b17..054702972 100644 --- a/tests/unit/authorization/test_middleware.py +++ b/tests/unit/authorization/test_middleware.py @@ -322,6 +322,37 @@ async def test_everyone_role_added( Action.QUERY, {"employee", "*"} ) + @pytest.mark.asyncio + async def test_authorization_metric_errors_do_not_mask_success( + self, + mocker: MockerFixture, + dummy_auth_tuple: AuthTuple, + mock_resolvers: tuple[MockType, MockType], + ) -> None: + """Test metric recorder failures do not fail successful authorization.""" + mocker.patch( + "authorization.middleware.get_authorization_resolvers", + return_value=mock_resolvers, + ) + mock_check = mocker.patch( + "authorization.middleware.recording.record_authorization_check", + side_effect=RuntimeError("metric backend unavailable"), + ) + mock_duration = mocker.patch( + "authorization.middleware.recording.record_authorization_duration" + ) + mock_logger = mocker.patch("authorization.middleware.logger") + + await _perform_authorization_check(Action.QUERY, (), {"auth": dummy_auth_tuple}) + + mock_check.assert_called_once_with(Action.QUERY.value, "success") + mock_duration.assert_called_once() + assert mock_duration.call_args.args[:2] == (Action.QUERY.value, "success") + assert mock_duration.call_args.args[2] >= 0 + mock_logger.warning.assert_called_once_with( + "Failed to record authorization check metric", exc_info=True + ) + class TestAuthorizeDecorator: """Test cases for authorize decorator.""" diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py index 6c797142f..255e91421 100644 --- a/tests/unit/metrics/test_recording.py +++ b/tests/unit/metrics/test_recording.py @@ -9,6 +9,17 @@ from metrics import recording +@dataclass(frozen=True) +class CounterRecorderCase: + """Expected behavior for a counter-style metric recorder.""" + + metric_path: str + recorder: Callable[..., None] + args: tuple[object, ...] + labels: tuple[object, ...] + warning_message: str + + @dataclass(frozen=True) class HistogramRecorderCase: """Expected behavior for a histogram-style metric recorder.""" @@ -183,6 +194,42 @@ def recording_logger_fixture(mocker: MockerFixture) -> MockType: return mocker.patch("metrics.recording.logger") +@pytest.mark.parametrize( + "case", + [ + CounterRecorderCase( + metric_path="metrics.recording.metrics.authorization_checks_total", + recorder=recording.record_authorization_check, + args=("responses", "success"), + labels=("responses", "success"), + warning_message="Failed to update authorization metric", + ), + ], +) +def test_counter_recorders_update_metrics_and_log_errors( + mocker: MockerFixture, + recording_logger: MockType, + case: CounterRecorderCase, +) -> None: + """Test new single-counter helpers with shared success and failure coverage.""" + mock_metric = mocker.patch(case.metric_path) + + case.recorder(*case.args) + + mock_metric.labels.assert_called_once_with(*case.labels) + mock_metric.labels.return_value.inc.assert_called_once() + recording_logger.warning.assert_not_called() + + mock_metric.reset_mock() + recording_logger.reset_mock() + mock_metric.labels.return_value.inc.side_effect = AttributeError("missing") + case.recorder(*case.args) + + recording_logger.warning.assert_called_once_with( + case.warning_message, exc_info=True + ) + + @pytest.mark.parametrize( "case", [ @@ -194,6 +241,14 @@ def recording_logger_fixture(mocker: MockerFixture) -> MockType: duration=1.5, warning_message="Failed to update LLM inference duration metric", ), + HistogramRecorderCase( + metric_path="metrics.recording.metrics.authorization_duration_seconds", + recorder=recording.record_authorization_duration, + args=("responses", "success", 0.25), + labels=("responses", "success"), + duration=0.25, + warning_message="Failed to update authorization duration metric", + ), ], ) def test_histogram_recorders_observe_metrics_and_log_errors( @@ -208,8 +263,10 @@ def test_histogram_recorders_observe_metrics_and_log_errors( mock_metric.labels.assert_called_once_with(*case.labels) mock_metric.labels.return_value.observe.assert_called_once_with(case.duration) + recording_logger.warning.assert_not_called() mock_metric.reset_mock() + recording_logger.reset_mock() mock_metric.labels.return_value.observe.side_effect = TypeError("bad") case.recorder(*case.args) @@ -313,3 +370,25 @@ def test_record_auth_duration_logs_metric_errors( recording_logger.warning.assert_called_once_with( "Failed to update authentication duration metric", exc_info=True ) + + +def test_record_authorization_check_bounds_labels(mocker: MockerFixture) -> None: + """Test authorization check labels are normalized before recording.""" + mock_metric = mocker.patch("metrics.recording.metrics.authorization_checks_total") + + recording.record_authorization_check("customer-123", "unexpected") + + mock_metric.labels.assert_called_once_with("unknown", "error") + mock_metric.labels.return_value.inc.assert_called_once() + + +def test_record_authorization_duration_bounds_labels(mocker: MockerFixture) -> None: + """Test authorization duration labels are normalized before recording.""" + mock_metric = mocker.patch( + "metrics.recording.metrics.authorization_duration_seconds" + ) + + recording.record_authorization_duration("customer-123", "unexpected", 0.25) + + mock_metric.labels.assert_called_once_with("unknown", "error") + mock_metric.labels.return_value.observe.assert_called_once_with(0.25) From 0eccec015eab1f8060134bd6bec6cb37a17c1185 Mon Sep 17 00:00:00 2001 From: Major Hayden Date: Wed, 29 Apr 2026 18:24:35 -0500 Subject: [PATCH 3/3] feat: add quota monitoring metrics Signed-off-by: Major Hayden --- src/app/endpoints/responses.py | 34 ++++++++++- src/app/endpoints/rlsapi_v1.py | 70 +++++++++++++++++++-- src/metrics/__init__.py | 33 ++++++++++ src/metrics/recording.py | 71 +++++++++++++++++++++- tests/unit/app/endpoints/test_responses.py | 24 ++++++++ tests/unit/app/endpoints/test_rlsapi_v1.py | 40 ++++++++++++ tests/unit/metrics/test_recording.py | 57 ++++++++++++++++- 7 files changed, 320 insertions(+), 9 deletions(-) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 2705a27dd..2a7dc9ace 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -138,6 +138,37 @@ def _get_user_agent(request: Request) -> Optional[str]: return sanitized or None +def _check_response_quota(user_id: str, endpoint_path: str) -> None: + """Check response quota availability and record bounded quota metrics.""" + quota_start_time = time.monotonic() + try: + check_tokens_available(configuration.quota_limiters, user_id) + except HTTPException: + recording.record_quota_check( + endpoint_path, + recording.QUOTA_TYPE_USER_ID, + recording.QUOTA_RESULT_FAILURE, + time.monotonic() - quota_start_time, + ) + raise + except Exception: # pylint: disable=broad-exception-caught + # Unexpected quota backend failures still need bounded metrics before + # propagating to the endpoint error handling layer. + recording.record_quota_check( + endpoint_path, + recording.QUOTA_TYPE_USER_ID, + recording.QUOTA_RESULT_ERROR, + time.monotonic() - quota_start_time, + ) + raise + recording.record_quota_check( + endpoint_path, + recording.QUOTA_TYPE_USER_ID, + recording.QUOTA_RESULT_SUCCESS, + time.monotonic() - quota_start_time, + ) + + responses_response: dict[int | str, dict[str, Any]] = { 200: ResponsesResponse.openapi_response(), 401: UnauthorizedResponse.openapi_response( @@ -358,11 +389,12 @@ async def responses_endpoint_handler( started_at = datetime.now(UTC) rh_identity_context = get_rh_identity_context(request) user_id, _, _, token = auth + endpoint_path = ENDPOINT_PATH_RESPONSES await check_mcp_auth(configuration, mcp_headers, token, request.headers) # Check token availability - check_tokens_available(configuration.quota_limiters, user_id) + _check_response_quota(user_id, endpoint_path) # Enforce RBAC: optionally disallow overriding model in requests validate_model_provider_override( diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 08555a4d5..14ff8e823 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -575,6 +575,63 @@ def _resolve_quota_subject(request: Request, auth: AuthTuple) -> Optional[str]: return system_id +def _check_infer_quota( + request: Request, auth: AuthTuple, endpoint_path: str +) -> Optional[str]: + """Check infer quota availability and record bounded quota metrics. + + Resolves the quota subject from the request and auth context, then + verifies that the subject has tokens available. All outcomes (success, + failure, error, skipped) are recorded as Prometheus metrics. + + Args: + request: The incoming FastAPI request used to resolve the quota subject. + auth: Authentication tuple ``(user_id, username, skip_userid_check, token)``. + endpoint_path: API endpoint path for metric labeling. + + Returns: + The resolved quota subject identifier, or ``None`` when quota is disabled. + + Raises: + HTTPException: Re-raised from the quota limiter when the subject has + exhausted its token allowance (HTTP 429). + """ + quota_id = _resolve_quota_subject(request, auth) + quota_type = configuration.rlsapi_v1.quota_subject or "disabled" + if quota_id is None: + recording.record_quota_check( + endpoint_path, quota_type, recording.QUOTA_RESULT_SKIPPED, 0.0 + ) + return None + + quota_start_time = time.monotonic() + try: + check_tokens_available(configuration.quota_limiters, quota_id) + except HTTPException: + recording.record_quota_check( + endpoint_path, + quota_type, + recording.QUOTA_RESULT_FAILURE, + time.monotonic() - quota_start_time, + ) + raise + except Exception: # pylint: disable=broad-exception-caught + recording.record_quota_check( + endpoint_path, + quota_type, + recording.QUOTA_RESULT_ERROR, + time.monotonic() - quota_start_time, + ) + raise + recording.record_quota_check( + endpoint_path, + quota_type, + recording.QUOTA_RESULT_SUCCESS, + time.monotonic() - quota_start_time, + ) + return quota_id + + def _build_infer_response( response_text: str, request_id: str, @@ -733,16 +790,17 @@ async def infer_endpoint( # pylint: disable=R0914,R0915 logger.info("Processing rlsapi v1 /infer request %s", request_id) - # Quota enforcement: resolve subject and check availability before any work. - # No-op when quota_subject is not configured or no quota limiters exist. - quota_id = _resolve_quota_subject(request, auth) - if quota_id is not None: + # Quota enforcement: check availability before any work and record metrics for + # both enforced and disabled quota paths. + quota_subject = configuration.rlsapi_v1.quota_subject + if quota_subject is not None: logger.info( "Checking quota availability for rlsapi v1 request %s using subject type %s", request_id, - configuration.rlsapi_v1.quota_subject, + quota_subject, ) - check_tokens_available(configuration.quota_limiters, quota_id) + quota_id = _check_infer_quota(request, auth, endpoint_path) + if quota_id is not None: logger.info( "Quota availability check passed for rlsapi v1 request %s", request_id ) diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 5324f35e7..c4d78f6b9 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -51,6 +51,21 @@ 5.0, float("inf"), ) + +QUOTA_CHECK_DURATION_BUCKETS: Final[tuple[float, ...]] = ( + 0.001, + 0.005, + 0.01, + 0.025, + 0.05, + 0.1, + 0.25, + 0.5, + 1.0, + 2.5, + 5.0, + float("inf"), +) # Counter to track REST API calls # This will be used to count how many times each API endpoint is called # and the status code of the response @@ -144,3 +159,21 @@ ["action", "result"], buckets=AUTHORIZATION_DURATION_BUCKETS, ) + +# Counter to track pre-request quota checks. Labels must stay bounded: +# endpoint uses static route patterns, quota_type is a configured quota subject, +# and result is one terminal state from the recording helper. +quota_checks_total: Final[Counter] = Counter( + "ls_quota_checks_total", + "Quota availability checks", + ["endpoint", "quota_type", "result"], +) + +# Histogram to measure quota availability check latency with sub-second buckets. +# It uses the same bounded endpoint/quota_type/result labels as the counter. +quota_check_duration_seconds: Final[Histogram] = Histogram( + "ls_quota_check_duration_seconds", + "Quota availability check duration", + ["endpoint", "quota_type", "result"], + buckets=QUOTA_CHECK_DURATION_BUCKETS, +) diff --git a/src/metrics/recording.py b/src/metrics/recording.py index 05d9a8a92..d940273cd 100644 --- a/src/metrics/recording.py +++ b/src/metrics/recording.py @@ -134,6 +134,48 @@ def normalize_authorization_result(result: str) -> str: return result return AUTHORIZATION_RESULT_ERROR + +QUOTA_TYPE_USER_ID: Final[str] = "user_id" +QUOTA_TYPE_ORG_ID: Final[str] = "org_id" +QUOTA_TYPE_SYSTEM_ID: Final[str] = "system_id" +QUOTA_TYPE_DISABLED: Final[str] = "disabled" +QUOTA_RESULT_SUCCESS: Final[str] = "success" +QUOTA_RESULT_FAILURE: Final[str] = "failure" +QUOTA_RESULT_SKIPPED: Final[str] = "skipped" +QUOTA_RESULT_ERROR: Final[str] = "error" + +ALLOWED_QUOTA_TYPES: Final[frozenset[str]] = frozenset( + { + QUOTA_TYPE_USER_ID, + QUOTA_TYPE_ORG_ID, + QUOTA_TYPE_SYSTEM_ID, + QUOTA_TYPE_DISABLED, + } +) +ALLOWED_QUOTA_RESULTS: Final[frozenset[str]] = frozenset( + { + QUOTA_RESULT_SUCCESS, + QUOTA_RESULT_FAILURE, + QUOTA_RESULT_SKIPPED, + QUOTA_RESULT_ERROR, + } +) + + +def normalize_quota_type(quota_type: str) -> str: + """Return a bounded quota type label for Prometheus cardinality safety.""" + if quota_type in ALLOWED_QUOTA_TYPES: + return quota_type + return "unknown" + + +def normalize_quota_result(result: str) -> str: + """Return a bounded quota result label for Prometheus cardinality safety.""" + if result in ALLOWED_QUOTA_RESULTS: + return result + return QUOTA_RESULT_ERROR + + @contextmanager def measure_response_duration(path: str) -> Iterator[None]: """Measure REST API response duration for a route path. @@ -278,7 +320,6 @@ def record_llm_inference_duration( logger.warning("Failed to update LLM inference duration metric", exc_info=True) - def record_auth_attempt(auth_module: str, result: str, reason: str) -> None: """Record one authentication attempt. @@ -354,3 +395,31 @@ def record_authorization_duration(action: str, result: str, duration: float) -> ).observe(duration) except (AttributeError, TypeError, ValueError): logger.warning("Failed to update authorization duration metric", exc_info=True) + + +def record_quota_check( + endpoint_path: str, quota_type: str, result: str, duration: float +) -> None: + """Record a quota availability check. + + Args: + endpoint_path: API endpoint path for metric labeling. + quota_type: Bounded quota subject type, not the subject identifier. Out-of-set + values are recorded as ``unknown``. + result: Bounded result label. Out-of-set values are recorded as ``error``. + duration: Quota check duration in seconds. + """ + normalized_quota_type = normalize_quota_type(quota_type) + normalized_result = normalize_quota_result(result) + try: + metrics.quota_checks_total.labels( + endpoint_path, normalized_quota_type, normalized_result + ).inc() + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update quota check counter", exc_info=True) + try: + metrics.quota_check_duration_seconds.labels( + endpoint_path, normalized_quota_type, normalized_result + ).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update quota check duration metric", exc_info=True) diff --git a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py index b5d2e5bd9..096f2f450 100644 --- a/tests/unit/app/endpoints/test_responses.py +++ b/tests/unit/app/endpoints/test_responses.py @@ -18,6 +18,7 @@ from pytest_mock import MockerFixture from app.endpoints.responses import ( + _check_response_quota, _is_server_mcp_output_item, _sanitize_response_dict, _should_filter_mcp_chunk, @@ -278,6 +279,29 @@ def _request_with_previous_response_id( return request +def test_check_response_quota_records_unexpected_errors( + minimal_config: AppConfig, + mocker: MockerFixture, +) -> None: + """Test unexpected quota failures are recorded before being re-raised.""" + mocker.patch(f"{MODULE}.configuration", minimal_config) + mocker.patch( + f"{MODULE}.check_tokens_available", + side_effect=RuntimeError("quota backend unavailable"), + ) + mock_record = mocker.patch(f"{MODULE}.recording.record_quota_check") + + with pytest.raises(RuntimeError, match="quota backend unavailable"): + _check_response_quota("user-123", "/v1/responses") + + mock_record.assert_called_once() + endpoint_path, quota_type, result, duration = mock_record.call_args.args + assert endpoint_path == "/v1/responses" + assert quota_type == "user_id" + assert result == "error" + assert duration >= 0 + + class TestResponsesEndpointHandler: """Unit tests for responses_endpoint_handler.""" diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py index 38227cf4c..254c8781b 100644 --- a/tests/unit/app/endpoints/test_rlsapi_v1.py +++ b/tests/unit/app/endpoints/test_rlsapi_v1.py @@ -1185,6 +1185,7 @@ async def test_infer_quota_skipped_when_not_configured( """Test /infer skips quota calls when quota_subject is None (default).""" mock_check = mocker.patch("app.endpoints.rlsapi_v1.check_tokens_available") mock_consume = mocker.patch("app.endpoints.rlsapi_v1.consume_query_tokens") + mock_record = mocker.patch("app.endpoints.rlsapi_v1.recording.record_quota_check") await infer_endpoint( infer_request=RlsapiV1InferRequest(question="How do I list files?"), @@ -1195,6 +1196,12 @@ async def test_infer_quota_skipped_when_not_configured( mock_check.assert_not_called() mock_consume.assert_not_called() + mock_record.assert_called_once_with( + rlsapi_v1.ENDPOINT_PATH_INFER, + rlsapi_v1.recording.QUOTA_TYPE_DISABLED, + rlsapi_v1.recording.QUOTA_RESULT_SKIPPED, + 0.0, + ) @pytest.mark.asyncio @@ -1224,6 +1231,39 @@ async def test_infer_quota_exceeded_returns_429( assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS +@pytest.mark.asyncio +async def test_infer_quota_records_unexpected_errors( + mocker: MockerFixture, + mock_quota_config: Callable[[str], None], + mock_llm_response: None, + mock_auth_resolvers: None, + mock_request_factory: Callable[..., Any], + mock_background_tasks: Any, +) -> None: + """Test unexpected quota failures are recorded before being re-raised.""" + mock_quota_config("user_id") + mocker.patch( + "app.endpoints.rlsapi_v1.check_tokens_available", + side_effect=RuntimeError("quota backend unavailable"), + ) + mock_record = mocker.patch("app.endpoints.rlsapi_v1.recording.record_quota_check") + + with pytest.raises(RuntimeError, match="quota backend unavailable"): + await infer_endpoint( + infer_request=RlsapiV1InferRequest(question="How do I list files?"), + request=mock_request_factory(), + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + + mock_record.assert_called_once() + endpoint_path, quota_type, result, duration = mock_record.call_args.args + assert endpoint_path == rlsapi_v1.ENDPOINT_PATH_INFER + assert quota_type == "user_id" + assert result == rlsapi_v1.recording.QUOTA_RESULT_ERROR + assert duration >= 0 + + @pytest.mark.parametrize( ("quota_subject", "rh_identity_setup", "expected_subject"), [ diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py index 255e91421..dc5260d8f 100644 --- a/tests/unit/metrics/test_recording.py +++ b/tests/unit/metrics/test_recording.py @@ -275,7 +275,6 @@ def test_histogram_recorders_observe_metrics_and_log_errors( ) - def test_record_llm_inference_duration_bounds_result_label( mocker: MockerFixture, ) -> None: @@ -392,3 +391,59 @@ def test_record_authorization_duration_bounds_labels(mocker: MockerFixture) -> N mock_metric.labels.assert_called_once_with("unknown", "error") mock_metric.labels.return_value.observe.assert_called_once_with(0.25) + + +@pytest.mark.parametrize("failing_metric", ["counter", "histogram"]) +def test_record_quota_check_updates_metrics_and_logs_errors( + mocker: MockerFixture, + recording_logger: MockType, + failing_metric: str, +) -> None: + """Test quota helper counter and histogram updates plus both failure points.""" + mock_counter = mocker.patch("metrics.recording.metrics.quota_checks_total") + mock_histogram = mocker.patch( + "metrics.recording.metrics.quota_check_duration_seconds" + ) + + recording.record_quota_check("/v1/infer", "org_id", "success", 0.75) + + mock_counter.labels.assert_called_once_with("/v1/infer", "org_id", "success") + mock_counter.labels.return_value.inc.assert_called_once() + mock_histogram.labels.assert_called_once_with("/v1/infer", "org_id", "success") + mock_histogram.labels.return_value.observe.assert_called_once_with(0.75) + recording_logger.warning.assert_not_called() + + mock_counter.reset_mock() + mock_histogram.reset_mock() + recording_logger.reset_mock() + if failing_metric == "counter": + mock_counter.labels.return_value.inc.side_effect = TypeError("bad") + expected_warning = "Failed to update quota check counter" + else: + mock_histogram.labels.return_value.observe.side_effect = TypeError("bad") + expected_warning = "Failed to update quota check duration metric" + + recording.record_quota_check("/v1/infer", "org_id", "failure", 0.75) + + recording_logger.warning.assert_called_once_with(expected_warning, exc_info=True) + + # With independent try/except blocks, the non-failing metric must still update. + if failing_metric == "counter": + mock_histogram.labels.return_value.observe.assert_called_once_with(0.75) + else: + mock_counter.labels.return_value.inc.assert_called_once() + + +def test_record_quota_check_bounds_labels(mocker: MockerFixture) -> None: + """Test quota helper maps unexpected label values to bounded fallbacks.""" + mock_counter = mocker.patch("metrics.recording.metrics.quota_checks_total") + mock_histogram = mocker.patch( + "metrics.recording.metrics.quota_check_duration_seconds" + ) + + recording.record_quota_check("/v1/responses", "customer-123", "timeout", 0.25) + + mock_counter.labels.assert_called_once_with("/v1/responses", "unknown", "error") + mock_counter.labels.return_value.inc.assert_called_once() + mock_histogram.labels.assert_called_once_with("/v1/responses", "unknown", "error") + mock_histogram.labels.return_value.observe.assert_called_once_with(0.25)