diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 704a628a1..2705a27dd 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -2,7 +2,6 @@ """Handler for REST API call to provide answer using Responses API (LCORE specification).""" -import asyncio import json import time from collections.abc import AsyncIterator, Sequence @@ -32,6 +31,11 @@ APIStatusError as OpenAIAPIStatusError, ) +from app.endpoints.responses_telemetry import ( + queue_blocked_response_event, + queue_completed_response_event, + queue_responses_error_event, +) from authentication import get_auth_dependency from authentication.interface import AuthTuple from authorization.azure_token_manager import AzureEntraIDManager @@ -60,7 +64,6 @@ from models.common.responses.responses_context import ResponsesContext from models.common.turn_summary import TurnSummary from models.config import Action -from observability import ResponsesEventData, build_responses_event, send_splunk_event from utils.conversations import append_turn_items_to_conversation from utils.endpoints import ( check_configuration_loaded, @@ -159,96 +162,6 @@ def _get_user_agent(request: Request) -> Optional[str]: } -# Strong references for fire-and-forget telemetry tasks so they aren't -# garbage-collected before completion (the event loop only holds weak refs). -_background_splunk_tasks: set[asyncio.Task[None]] = set() - - -def _queue_responses_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments - background_tasks: Optional[BackgroundTasks], - input_text: str, - response_text: str, - conversation_id: str, - model: str, - rh_identity_context: tuple[str, str], - inference_time: float, - sourcetype: str, - input_tokens: int = 0, - output_tokens: int = 0, - fire_and_forget: bool = False, - user_agent: Optional[str] = None, -) -> None: - """Build and queue a Splunk telemetry event for the responses endpoint. - - No-op when background_tasks is None and fire_and_forget is False - (Splunk telemetry disabled). - - Args: - background_tasks: FastAPI background task manager, or None if disabled. - input_text: User input text. - response_text: Response text from LLM or shield. - conversation_id: Conversation identifier. - model: Model name used for inference. - rh_identity_context: Tuple of (org_id, system_id) from RH identity. - inference_time: Request processing duration in seconds. - sourcetype: Splunk sourcetype for the event. - input_tokens: Number of prompt tokens consumed. - output_tokens: Number of completion tokens produced. - fire_and_forget: When True, dispatch via asyncio.create_task() instead - of background_tasks. Use for error paths where an HTTPException - follows, since FastAPI discards BackgroundTasks on non-2xx responses. - user_agent: Sanitized User-Agent string from the request header, or None. - """ - if not fire_and_forget and background_tasks is None: - return - org_id, system_id = rh_identity_context - event_data = ResponsesEventData( - input_text=input_text, - response_text=response_text, - conversation_id=conversation_id, - model=model, - org_id=org_id, - system_id=system_id, - inference_time=inference_time, - input_tokens=input_tokens, - output_tokens=output_tokens, - user_agent=user_agent, - ) - event = build_responses_event(event_data) - if fire_and_forget: - task = asyncio.create_task(send_splunk_event(event, sourcetype)) - _background_splunk_tasks.add(task) - task.add_done_callback(_background_splunk_tasks.discard) - elif background_tasks is not None: - background_tasks.add_task(send_splunk_event, event, sourcetype) - - -def _queue_responses_error_event( - error: Exception, - api_params: ResponsesApiParams, - context: ResponsesContext, -) -> None: - """Queue fire-and-forget Splunk telemetry for a Responses API error. - - Args: - error: The backend exception being converted into an HTTP error. - api_params: Responses API parameters for the failed request. - context: Request-scoped Responses API context. - """ - _queue_responses_splunk_event( - background_tasks=context.background_tasks, - input_text=context.input_text, - response_text=str(error), - conversation_id=normalize_conversation_id(api_params.conversation), - model=api_params.model, - rh_identity_context=context.rh_identity_context, - inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), - sourcetype="responses_error", - fire_and_forget=True, - user_agent=context.user_agent, - ) - - def _http_exception_for_response_api_error( error: Exception, api_params: ResponsesApiParams, @@ -297,7 +210,7 @@ def _raise_response_api_http_exception( http_exception = _http_exception_for_response_api_error(error, api_params) if http_exception is None: raise error - _queue_responses_error_event(error, api_params, context) + queue_responses_error_event(error, api_params, context) raise http_exception from error @@ -321,31 +234,6 @@ async def _persist_blocked_response_turn( ) -def _queue_blocked_response_event( - api_params: ResponsesApiParams, - context: ResponsesContext, - response_text: str, -) -> None: - """Queue Splunk telemetry for a shield-blocked Responses API request. - - Args: - api_params: Responses API parameters for the blocked request. - context: Request-scoped Responses API context. - response_text: Refusal text sent to the client. - """ - _queue_responses_splunk_event( - background_tasks=context.background_tasks, - input_text=context.input_text, - response_text=response_text, - conversation_id=normalize_conversation_id(api_params.conversation), - model=api_params.model, - rh_identity_context=context.rh_identity_context, - inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), - sourcetype="responses_shield_blocked", - user_agent=context.user_agent, - ) - - async def _append_previous_response_turn( api_params: ResponsesApiParams, context: ResponsesContext, @@ -419,39 +307,6 @@ def _store_response_query_results( ) -def _queue_completed_response_event( - api_params: ResponsesApiParams, - context: ResponsesContext, - turn_summary: TurnSummary, - completed_at: datetime, - response_text: str, -) -> None: - """Queue Splunk telemetry for a completed Responses API request. - - Args: - api_params: Responses API parameters for the completed request. - context: Request-scoped Responses API context. - turn_summary: Summary containing token usage for telemetry. - completed_at: Time when response handling completed. - response_text: Final text sent to the client. - """ - if context.moderation_result.decision != "passed": - return - _queue_responses_splunk_event( - background_tasks=context.background_tasks, - input_text=context.input_text, - response_text=response_text, - conversation_id=normalize_conversation_id(api_params.conversation), - model=api_params.model, - rh_identity_context=context.rh_identity_context, - inference_time=(completed_at - context.started_at).total_seconds(), - sourcetype="responses_completed", - input_tokens=turn_summary.token_usage.input_tokens, - output_tokens=turn_summary.token_usage.output_tokens, - user_agent=context.user_agent, - ) - - @router.post( "/responses", responses=responses_response, @@ -682,7 +537,7 @@ async def handle_streaming_response( turn_summary.llm_response = context.moderation_result.message generator = shield_violation_generator(api_params, context) await _persist_blocked_response_turn(api_params, context) - _queue_blocked_response_event( + queue_blocked_response_event( api_params, context, context.moderation_result.message, @@ -1141,7 +996,7 @@ async def generate_response( completed_at, topic_summary, ) - _queue_completed_response_event( + queue_completed_response_event( api_params, context, turn_summary, @@ -1178,7 +1033,7 @@ async def handle_non_streaming_response( **api_params.echoed_params(configuration.rag_id_mapping), ) await _persist_blocked_response_turn(api_params, context) - _queue_blocked_response_event(api_params, context, output_text) + queue_blocked_response_event(api_params, context, output_text) else: inference_start_time = time.monotonic() inference_metric_recorded = False @@ -1258,7 +1113,7 @@ async def handle_non_streaming_response( completed_at, topic_summary, ) - _queue_completed_response_event( + queue_completed_response_event( api_params, context, turn_summary, diff --git a/src/app/endpoints/responses_telemetry.py b/src/app/endpoints/responses_telemetry.py new file mode 100644 index 000000000..8441a1e0d --- /dev/null +++ b/src/app/endpoints/responses_telemetry.py @@ -0,0 +1,162 @@ +"""Splunk telemetry helpers for the Responses API endpoint. + +Extracted from responses.py to reduce module size while keeping telemetry +functions co-located with the endpoint they serve. +""" + +from datetime import UTC, datetime +from typing import Optional + +from fastapi import BackgroundTasks + +from log import get_logger +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.responses.responses_context import ResponsesContext +from models.common.turn_summary import TurnSummary +from observability import ResponsesEventData, build_responses_event +from observability.splunk import dispatch_splunk_event +from utils.suid import normalize_conversation_id + +logger = get_logger(__name__) + + +def queue_responses_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments + background_tasks: Optional[BackgroundTasks], + input_text: str, + response_text: str, + conversation_id: str, + model: str, + rh_identity_context: tuple[str, str], + inference_time: float, + sourcetype: str, + input_tokens: int = 0, + output_tokens: int = 0, + fire_and_forget: bool = False, + user_agent: Optional[str] = None, +) -> None: + """Build and queue a Splunk telemetry event for the responses endpoint. + + No-op when background_tasks is None and fire_and_forget is False + (Splunk telemetry disabled). + + Args: + background_tasks: FastAPI background task manager, or None if disabled. + input_text: User input text. + response_text: Response text from LLM or shield. + conversation_id: Conversation identifier. + model: Model name used for inference. + rh_identity_context: Tuple of (org_id, system_id) from RH identity. + inference_time: Request processing duration in seconds. + sourcetype: Splunk sourcetype for the event. + input_tokens: Number of prompt tokens consumed. + output_tokens: Number of completion tokens produced. + fire_and_forget: When True, dispatch via asyncio.create_task() instead + of background_tasks. Use for error paths where an HTTPException + follows, since FastAPI discards BackgroundTasks on non-2xx responses. + user_agent: Sanitized User-Agent string from the request header, or None. + """ + if not fire_and_forget and background_tasks is None: + return + event_data = ResponsesEventData( + input_text=input_text, + response_text=response_text, + conversation_id=conversation_id, + model=model, + org_id=rh_identity_context[0], + system_id=rh_identity_context[1], + inference_time=inference_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + user_agent=user_agent, + ) + event = build_responses_event(event_data) + dispatch_splunk_event( + event, + sourcetype, + background_tasks=background_tasks, + fire_and_forget=fire_and_forget, + ) + + +def queue_responses_error_event( + error: Exception, + api_params: ResponsesApiParams, + context: ResponsesContext, +) -> None: + """Queue fire-and-forget Splunk telemetry for a Responses API error. + + Args: + error: The backend exception being converted into an HTTP error. + api_params: Responses API parameters for the failed request. + context: Request-scoped Responses API context. + """ + queue_responses_splunk_event( + background_tasks=context.background_tasks, + input_text=context.input_text, + response_text=type(error).__name__, + conversation_id=normalize_conversation_id(api_params.conversation), + model=api_params.model, + rh_identity_context=context.rh_identity_context, + inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), + sourcetype="responses_error", + fire_and_forget=True, + user_agent=context.user_agent, + ) + + +def queue_blocked_response_event( + api_params: ResponsesApiParams, + context: ResponsesContext, + response_text: str, +) -> None: + """Queue Splunk telemetry for a shield-blocked Responses API request. + + Args: + api_params: Responses API parameters for the blocked request. + context: Request-scoped Responses API context. + response_text: Refusal text sent to the client. + """ + queue_responses_splunk_event( + background_tasks=context.background_tasks, + input_text=context.input_text, + response_text=response_text, + conversation_id=normalize_conversation_id(api_params.conversation), + model=api_params.model, + rh_identity_context=context.rh_identity_context, + inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), + sourcetype="responses_shield_blocked", + user_agent=context.user_agent, + ) + + +def queue_completed_response_event( + api_params: ResponsesApiParams, + context: ResponsesContext, + turn_summary: TurnSummary, + completed_at: datetime, + response_text: str, +) -> None: + """Queue Splunk telemetry for a completed Responses API request. + + Args: + api_params: Responses API parameters for the completed request. + context: Request-scoped Responses API context. + turn_summary: Summary containing token usage for telemetry. + completed_at: Time when response handling completed. + response_text: Final text sent to the client. + """ + if context.moderation_result.decision != "passed": + return + queue_responses_splunk_event( + background_tasks=context.background_tasks, + input_text=context.input_text, + response_text=response_text, + conversation_id=normalize_conversation_id(api_params.conversation), + model=api_params.model, + rh_identity_context=context.rh_identity_context, + inference_time=(completed_at - context.started_at).total_seconds(), + sourcetype="responses_completed", + input_tokens=turn_summary.token_usage.input_tokens, + output_tokens=turn_summary.token_usage.output_tokens, + user_agent=context.user_agent, + ) diff --git a/src/authentication/api_key_token.py b/src/authentication/api_key_token.py index 2b9b27900..b5cdf473c 100644 --- a/src/authentication/api_key_token.py +++ b/src/authentication/api_key_token.py @@ -7,12 +7,14 @@ """ import secrets +import time from fastapi import HTTPException, Request, status from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_APIKEY_TOKEN, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -59,16 +61,32 @@ async def __call__(self, request: Request) -> AuthTuple: HTTPException: If the bearer token is missing or doesn't match the configured API key (HTTP 401). """ + start_time = time.monotonic() + # try to extract user token from request - user_token = extract_user_token(request.headers) + try: + user_token = extract_user_token(request.headers) + except HTTPException as exc: + # Distinguish missing header from malformed token + reason = "missing_token" + if isinstance( + exc.detail, dict + ) and "No Authorization header" in exc.detail.get("cause", ""): + reason = "missing_header" + record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "failure", reason, start_time) + raise # API Key validation. Use secrets.compare_digest for constant-time comparison if not secrets.compare_digest( user_token, self.config.api_key.get_secret_value() ): + record_auth_metrics( + AUTH_MOD_APIKEY_TOKEN, "failure", "invalid_key", start_time + ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key", ) + record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "success", "valid_key", start_time) return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authentication/jwk_token.py b/src/authentication/jwk_token.py index 7cc15870b..db8241b73 100644 --- a/src/authentication/jwk_token.py +++ b/src/authentication/jwk_token.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" import json +import time from asyncio import Lock from collections.abc import Callable from typing import Any @@ -17,12 +18,13 @@ from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_JWK_TOKEN, DEFAULT_VIRTUAL_PATH, ) from log import get_logger -from models.api.responses.error import UnauthorizedResponse +from models.api.responses.error import ServiceUnavailableResponse, UnauthorizedResponse from models.config import JwkConfiguration logger = get_logger(__name__) @@ -141,6 +143,93 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key: return _internal +async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet: + """Load the configured JWK set and record bounded auth failures.""" + try: + return await get_jwk_set(str(config.url)) + except aiohttp.ClientError as exc: + logger.error("Failed to fetch JWK set: %s", exc) + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "jwk_fetch_error", start_time + ) + response = ServiceUnavailableResponse( + backend_name="JWK key server", + cause="Unable to reach authentication key server", + ) + raise HTTPException(**response.model_dump()) from exc + except json.JSONDecodeError as exc: + logger.error("Invalid JSON in JWK set response: %s", exc) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_json", start_time) + response = ServiceUnavailableResponse( + backend_name="JWK key server", + cause="Authentication key server returned invalid data", + ) + raise HTTPException(**response.model_dump()) from exc + except JoseError as exc: + logger.error("Invalid JWK set format: %s", exc) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_jwk", start_time) + response = ServiceUnavailableResponse( + backend_name="JWK key server", + cause="Authentication keys are malformed", + ) + raise HTTPException(**response.model_dump()) from exc + + +def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any: + """Decode a JWT and record bounded auth failures.""" + try: + return jwt.decode(user_token, key=key_resolver_func(jwk_set)) + except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: + logger.warning("Token decode error: %s", exc) + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time + ) + if isinstance(exc, KeyNotFoundError): + cause = "Token signed by unknown key" + elif isinstance(exc, BadSignatureError): + cause = "Invalid token signature" + elif isinstance(exc, DecodeError): + cause = "Token could not be decoded" + else: + cause = "Token format error" + response = UnauthorizedResponse(cause=cause) + raise HTTPException(**response.model_dump()) from exc + + +def _validate_jwk_claims(claims: Any, start_time: float) -> None: + """Validate decoded JWT claims and record bounded auth failures.""" + try: + claims.validate() + except ExpiredTokenError as exc: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "token_expired", start_time) + response = UnauthorizedResponse(cause="Token has expired") + raise HTTPException(**response.model_dump()) from exc + except JoseError as exc: + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "token_validation_error", start_time + ) + response = UnauthorizedResponse(cause="Token validation failed") + raise HTTPException(**response.model_dump()) from exc + + +def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str: + """Return a required JWT claim and record bounded auth failures when missing.""" + try: + value = claims[claim_name] + except KeyError as exc: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time) + response = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}") + raise HTTPException(**response.model_dump()) from exc + if not isinstance(value, str) or not value.strip(): + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time) + response = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}") + invalid_claim_error = ValueError( + f"Token claim {claim_name} must be a non-empty string" + ) + raise HTTPException(**response.model_dump()) from invalid_claim_error + return value + + class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods """JWK AuthDependency class for JWK-based JWT authentication.""" @@ -189,73 +278,40 @@ async def __call__(self, request: Request) -> AuthTuple: extracted from the validated JWT. Only returned on successful authentication; all error paths raise HTTPException. """ + start_time = time.monotonic() + if not request.headers.get("Authorization"): + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time + ) response = UnauthorizedResponse(cause="No Authorization header found") raise HTTPException(**response.model_dump()) - user_token = extract_user_token(request.headers) - - try: - jwk_set = await get_jwk_set(str(self.config.url)) - except aiohttp.ClientError as exc: - logger.error("Failed to fetch JWK set: %s", exc) - response = UnauthorizedResponse( - cause="Unable to reach authentication key server" - ) - raise HTTPException(**response.model_dump()) from exc - except json.JSONDecodeError as exc: - logger.error("Invalid JSON in JWK set response: %s", exc) - response = UnauthorizedResponse( - cause="Authentication key server returned invalid data" - ) - raise HTTPException(**response.model_dump()) from exc - except JoseError as exc: - logger.error("Invalid JWK set format: %s", exc) - response = UnauthorizedResponse(cause="Authentication keys are malformed") - raise HTTPException(**response.model_dump()) from exc - - try: - claims = jwt.decode(user_token, key=key_resolver_func(jwk_set)) - except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: - logger.warning("Token decode error: %s", exc) - cause_map = { - KeyNotFoundError: "Token signed by unknown key", - BadSignatureError: "Invalid token signature", - DecodeError: "Token could not be decoded", - JoseError: "Token format error", - } - response = UnauthorizedResponse( - cause=cause_map.get(type(exc), "Unknown token error") - ) - raise HTTPException(**response.model_dump()) from exc - try: - claims.validate() - except ExpiredTokenError as exc: - response = UnauthorizedResponse(cause="Token has expired") - raise HTTPException(**response.model_dump()) from exc - except JoseError as exc: - response = UnauthorizedResponse(cause="Token validation failed") - raise HTTPException(**response.model_dump()) from exc - - try: - user_id: str = claims[self.config.jwt_configuration.user_id_claim] - except KeyError as exc: - missing_claim = self.config.jwt_configuration.user_id_claim - response = UnauthorizedResponse( - cause=f"Token missing claim: {missing_claim}" + user_token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time ) - raise HTTPException(**response.model_dump()) from exc - - try: - username: str = claims[self.config.jwt_configuration.username_claim] - except KeyError as exc: - missing_claim = self.config.jwt_configuration.username_claim - response = UnauthorizedResponse( - cause=f"Token missing claim: {missing_claim}" + raise + except Exception: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error while extracting JWK bearer token") + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", start_time ) - raise HTTPException(**response.model_dump()) from exc + raise + + jwk_set = await _get_jwk_set_for_auth(self.config, start_time) + claims = _decode_jwk_claims(user_token, jwk_set, start_time) + _validate_jwk_claims(claims, start_time) + user_id = _get_required_claim( + claims, self.config.jwt_configuration.user_id_claim, start_time + ) + username = _get_required_claim( + claims, self.config.jwt_configuration.username_claim, start_time + ) logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time) return user_id, username, self.skip_userid_check, user_token diff --git a/src/authentication/k8s.py b/src/authentication/k8s.py index 4a8c9dbdf..285f6c172 100644 --- a/src/authentication/k8s.py +++ b/src/authentication/k8s.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with K8S/OCP.""" import os +import time from http import HTTPStatus from typing import Optional, Self, cast @@ -10,9 +11,9 @@ from kubernetes.config import ConfigException from authentication.interface import NO_AUTH_TUPLE, AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from configuration import configuration -from constants import DEFAULT_VIRTUAL_PATH +from constants import AUTH_MOD_K8S, DEFAULT_VIRTUAL_PATH from log import get_logger from models.api.responses.error import ( ForbiddenResponse, @@ -382,7 +383,126 @@ def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReviewStatus] raise HTTPException(**response_obj.model_dump()) from e except Exception as e: # pylint: disable=broad-exception-caught logger.error("Unexpected error during TokenReview: %s", e) - return None + response_obj = InternalServerErrorResponse( + response="Internal server error", + cause=f"Unexpected error during TokenReview: {e}", + ) + raise HTTPException(**response_obj.model_dump()) from e + + +def _populate_kube_admin_uid( + user: kubernetes.client.V1UserInfo, start_time: float +) -> None: + """Populate kube:admin UID from the cluster ID and record bounded failures.""" + if user.username != "kube:admin": + return + try: + user.uid = K8sClientSingleton.get_cluster_id() + except K8sAPIConnectionError as e: + logger.error("Cannot connect to Kubernetes API: %s", e) + record_auth_metrics(AUTH_MOD_K8S, "failure", "k8s_api_unavailable", start_time) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + except K8sConfigurationError as e: + logger.error("Cluster configuration error: %s", e) + record_auth_metrics(AUTH_MOD_K8S, "failure", "k8s_config_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error while resolving kube:admin cluster ID") + record_auth_metrics(AUTH_MOD_K8S, "failure", "unexpected_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + + +def _create_subject_access_review( + user: kubernetes.client.V1UserInfo, virtual_path: str, start_time: float +) -> kubernetes.client.V1SubjectAccessReview: + """Create a Kubernetes SubjectAccessReview and record bounded failures.""" + try: + authorization_api = K8sClientSingleton.get_authz_api() + except Exception as e: + logger.error("Failed to initialize Kubernetes authorization API: %s", e) + record_auth_metrics( + AUTH_MOD_K8S, "failure", "authorization_check_error", start_time + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=f"Unable to initialize Kubernetes client: {e}", + ) + raise HTTPException(**response.model_dump()) from e + + try: + sar = kubernetes.client.V1SubjectAccessReview( + spec=kubernetes.client.V1SubjectAccessReviewSpec( + user=user.username, + groups=user.groups, + non_resource_attributes=kubernetes.client.V1NonResourceAttributes( + path=virtual_path, verb="get" + ), + ) + ) + return cast( + kubernetes.client.V1SubjectAccessReview, + authorization_api.create_subject_access_review(sar), + ) + except ApiException as e: + record_auth_metrics( + AUTH_MOD_K8S, "failure", "authorization_check_error", start_time + ) + if e.status is None: + logger.error( + "Kubernetes API error during SubjectAccessReview with no status code: %s", + e.reason, + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=f"Failed to connect to Kubernetes API: {e.reason}", + ) + raise HTTPException(**response.model_dump()) from e + + if ( + e.status >= HTTPStatus.INTERNAL_SERVER_ERROR + or e.status == HTTPStatus.TOO_MANY_REQUESTS + ): + logger.error( + "Kubernetes API unavailable during SubjectAccessReview (status %s): %s", + e.status, + e.reason, + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=f"Kubernetes API unavailable: {e.reason} (status {e.status})", + ) + raise HTTPException(**response.model_dump()) from e + + logger.error( + "Kubernetes API returned client error during SubjectAccessReview (status %s): %s", + e.status, + e.reason, + ) + response_obj = InternalServerErrorResponse( + response="Internal server error", + cause=f"Kubernetes API request failed: {e.reason} (status {e.status})", + ) + raise HTTPException(**response_obj.model_dump()) from e + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error during SubjectAccessReview") + record_auth_metrics(AUTH_MOD_K8S, "failure", "unexpected_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e class K8SAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods @@ -436,69 +556,47 @@ async def __call__(self, request: Request) -> AuthTuple: ------ HTTPException: If authentication or authorization fails. """ + start_time = time.monotonic() + # LCORE-694: Config option to skip authorization for readiness and liveness probe if not request.headers.get("Authorization"): if configuration.authentication_configuration.skip_for_health_probes: if request.url.path in ("/readiness", "/liveness"): + record_auth_metrics( + AUTH_MOD_K8S, "skipped", "health_probe", start_time + ) return NO_AUTH_TUPLE # Skip auth for metrics endpoint when configured if configuration.authentication_configuration.skip_for_metrics: if request.url.path in ("/metrics",): + record_auth_metrics(AUTH_MOD_K8S, "skipped", "metrics", start_time) return NO_AUTH_TUPLE - token = extract_user_token(request.headers) - user_info = get_user_info(token) + try: + token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics(AUTH_MOD_K8S, "failure", "missing_token", start_time) + raise + try: + user_info = get_user_info(token) + except HTTPException: + record_auth_metrics( + AUTH_MOD_K8S, "failure", "token_review_error", start_time + ) + raise if user_info is None: + record_auth_metrics(AUTH_MOD_K8S, "failure", "invalid_token", start_time) response = UnauthorizedResponse(cause="Invalid or expired Kubernetes token") raise HTTPException(**response.model_dump()) # Cast user to proper type for type checking user = cast(kubernetes.client.V1UserInfo, user_info.user) - if user.username == "kube:admin": - try: - user.uid = K8sClientSingleton.get_cluster_id() - except K8sAPIConnectionError as e: - # Kubernetes API is unreachable - return 503 - logger.error("Cannot connect to Kubernetes API: %s", e) - response = ServiceUnavailableResponse( - backend_name="Kubernetes API", - cause=str(e), - ) - raise HTTPException(**response.model_dump()) from e - except K8sConfigurationError as e: - # Cluster misconfiguration or client error - return 500 - logger.error("Cluster configuration error: %s", e) - response = InternalServerErrorResponse( - response="Internal server error", - cause=str(e), - ) - raise HTTPException(**response.model_dump()) from e - - try: - authorization_api = K8sClientSingleton.get_authz_api() - sar = kubernetes.client.V1SubjectAccessReview( - spec=kubernetes.client.V1SubjectAccessReviewSpec( - user=user.username, - groups=user.groups, - non_resource_attributes=kubernetes.client.V1NonResourceAttributes( - path=self.virtual_path, verb="get" - ), - ) - ) - sar_response = cast( - kubernetes.client.V1SubjectAccessReview, - authorization_api.create_subject_access_review(sar), - ) - - except Exception as e: - logger.error("API exception during SubjectAccessReview: %s", e) - response = ServiceUnavailableResponse( - backend_name="Kubernetes API", - cause="Unable to perform authorization check", - ) - raise HTTPException(**response.model_dump()) from e + _populate_kube_admin_uid(user, start_time) + sar_response = _create_subject_access_review( + user, self.virtual_path, start_time + ) sar_status = cast( kubernetes.client.V1SubjectAccessReviewStatus, sar_response.status @@ -507,9 +605,11 @@ async def __call__(self, request: Request) -> AuthTuple: username = cast(str, user.username) if not sar_status.allowed: + record_auth_metrics(AUTH_MOD_K8S, "failure", "not_authorized", start_time) response = ForbiddenResponse.endpoint(user_id=user_uid) raise HTTPException(**response.model_dump()) + record_auth_metrics(AUTH_MOD_K8S, "success", "authenticated", start_time) return ( user_uid, username, diff --git a/src/authentication/noop.py b/src/authentication/noop.py index 6d32f45c3..bab22479a 100644 --- a/src/authentication/noop.py +++ b/src/authentication/noop.py @@ -1,9 +1,13 @@ """Manage authentication flow for FastAPI endpoints with no-op auth.""" +import time + from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple +from authentication.utils import record_auth_metrics from constants import ( + AUTH_MOD_NOOP, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -47,6 +51,8 @@ async def __call__(self, request: Request) -> AuthTuple: - skip_userid_check: True to indicate the user ID check is skipped. - token: NO_USER_TOKEN. """ + start_time = time.monotonic() + logger.warning( "No-op authentication dependency is being used. " "The service is running in insecure mode intended solely for development purposes" @@ -54,6 +60,8 @@ async def __call__(self, request: Request) -> AuthTuple: # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) if not user_id: + record_auth_metrics(AUTH_MOD_NOOP, "failure", "empty_user_id", start_time) raise HTTPException(status_code=400, detail="user_id cannot be empty") logger.debug("Retrieved user ID: %s", user_id) + record_auth_metrics(AUTH_MOD_NOOP, "skipped", "no_auth_required", start_time) return user_id, DEFAULT_USER_NAME, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/authentication/noop_with_token.py b/src/authentication/noop_with_token.py index 0656d952a..bd8ec3943 100644 --- a/src/authentication/noop_with_token.py +++ b/src/authentication/noop_with_token.py @@ -6,14 +6,17 @@ - Reads a user token from request headers via `authentication.utils.extract_user_token`. - Reads `user_id` from query params (falls back to `DEFAULT_USER_UID`) and pairs it with `DEFAULT_USER_NAME`. -- Returns a tuple: (user_id, DEFAULT_USER_NAME, user_token). +- Returns a tuple: (user_id, DEFAULT_USER_NAME, skip_userid_check, user_token). """ +import time + from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_NOOP_WITH_TOKEN, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -59,15 +62,32 @@ async def __call__(self, request: Request) -> AuthTuple: - skip_userid_check: True to indicate user-id checks are skipped. - user_token: Token extracted from the request headers. """ + start_time = time.monotonic() + logger.warning( "No-op with token authentication dependency is being used. " "The service is running in insecure mode intended solely for development purposes" ) # try to extract user token from request - user_token = extract_user_token(request.headers) + try: + user_token = extract_user_token(request.headers) + except HTTPException as exc: + reason = ( + "missing_token" + if "No Authorization header" in str(exc.detail) + else "malformed_token" + ) + record_auth_metrics(AUTH_MOD_NOOP_WITH_TOKEN, "failure", reason, start_time) + raise # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) if not user_id: + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "failure", "empty_user_id", start_time + ) raise HTTPException(status_code=400, detail="user_id cannot be empty") logger.debug("Retrieved user ID: %s", user_id) + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "skipped", "no_auth_required", start_time + ) return user_id, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authentication/rh_identity.py b/src/authentication/rh_identity.py index f772ae9e5..bbda184e9 100644 --- a/src/authentication/rh_identity.py +++ b/src/authentication/rh_identity.py @@ -6,13 +6,16 @@ import base64 import json -from typing import Any, Optional +import time +from typing import Any, Final, Optional from fastapi import HTTPException, Request from authentication.interface import NO_AUTH_TUPLE, AuthInterface, AuthTuple +from authentication.utils import record_auth_metrics from configuration import configuration from constants import ( + AUTH_MOD_RH_IDENTITY, DEFAULT_RH_IDENTITY_MAX_HEADER_SIZE, DEFAULT_VIRTUAL_PATH, NO_USER_TOKEN, @@ -24,6 +27,9 @@ RH_INSIGHTS_REQUEST_ID_HEADER = "x-rh-insights-request-id" REQUEST_ID_HEADER = "x-request-id" +HEALTH_PROBE_SKIP_PATHS: Final[frozenset[str]] = frozenset({"/readiness", "/liveness"}) +METRICS_SKIP_PATH: Final[str] = "/metrics" + def _get_request_id(request: Request) -> str: """Return the inbound request identifier available during authentication.""" @@ -33,6 +39,30 @@ def _get_request_id(request: Request) -> str: ) +def _record_rh_identity_auth(result: str, reason: str, start_time: float) -> None: + """Record RH Identity authentication metrics with bounded labels.""" + record_auth_metrics(AUTH_MOD_RH_IDENTITY, result, reason, start_time) + + +def _normalized_request_path(request: Request) -> str: + """Return a canonical request path for exact skip-path comparisons.""" + return request.url.path.rstrip("/") or "/" + + +def _get_auth_skip_tuple(request: Request, start_time: float) -> Optional[AuthTuple]: + """Return an auth tuple for configured RH Identity skip paths.""" + request_path = _normalized_request_path(request) + if request_path in HEALTH_PROBE_SKIP_PATHS: + if configuration.authentication_configuration.skip_for_health_probes: + _record_rh_identity_auth("skipped", "health_probe", start_time) + return NO_AUTH_TUPLE + if request_path == METRICS_SKIP_PATH: + if configuration.authentication_configuration.skip_for_metrics: + _record_rh_identity_auth("skipped", "metrics", start_time) + return NO_AUTH_TUPLE + return None + + class RHIdentityData: """Extracts and validates Red Hat Identity header data. @@ -73,6 +103,12 @@ def _validate_structure(self) -> None: raise HTTPException(status_code=400, detail="Invalid identity data") identity = self.identity_data["identity"] + if not isinstance(identity, dict): + logger.warning( + "Identity validation failed: 'identity' is %s, expected dict", + type(identity).__name__, + ) + raise HTTPException(status_code=400, detail="Invalid identity data") if "type" not in identity: logger.warning("Identity validation failed: missing 'type' field") raise HTTPException(status_code=400, detail="Invalid identity data") @@ -106,6 +142,12 @@ def _validate_user_fields(self, identity: dict) -> None: ) raise HTTPException(status_code=400, detail="Invalid identity data") user = identity["user"] + if not isinstance(user, dict): + logger.warning( + "Identity validation failed: 'user' is %s, expected dict", + type(user).__name__, + ) + raise HTTPException(status_code=400, detail="Invalid identity data") if "user_id" not in user: logger.warning("Identity validation failed: missing 'user_id' in user data") raise HTTPException(status_code=400, detail="Invalid identity data") @@ -132,6 +174,12 @@ def _validate_system_fields(self, identity: dict) -> None: ) raise HTTPException(status_code=400, detail="Invalid identity data") system = identity["system"] + if not isinstance(system, dict): + logger.warning( + "Identity validation failed: 'system' is %s, expected dict", + type(system).__name__, + ) + raise HTTPException(status_code=400, detail="Invalid identity data") if "cn" not in system: logger.warning("Identity validation failed: missing 'cn' in system data") raise HTTPException(status_code=400, detail="Invalid identity data") @@ -321,18 +369,16 @@ async def __call__(self, request: Request) -> AuthTuple: - 400: Invalid base64, invalid JSON, or missing required fields - 403: Missing required entitlements """ + start_time = time.monotonic() + # Extract header identity_header = request.headers.get("x-rh-identity") if not identity_header: - # Skip auth for health probes when configured - if request.url.path.endswith(("/readiness", "/liveness")): - if configuration.authentication_configuration.skip_for_health_probes: - return NO_AUTH_TUPLE - # Skip auth for metrics endpoint when configured - if request.url.path.endswith("/metrics"): - if configuration.authentication_configuration.skip_for_metrics: - return NO_AUTH_TUPLE + auth_skip = _get_auth_skip_tuple(request, start_time) + if auth_skip is not None: + return auth_skip logger.warning("Missing x-rh-identity header") + _record_rh_identity_auth("failure", "missing_header", start_time) raise HTTPException(status_code=401, detail="Missing x-rh-identity header") # Enforce header size limit before decoding @@ -342,6 +388,7 @@ async def __call__(self, request: Request) -> AuthTuple: len(identity_header), self.max_header_size, ) + _record_rh_identity_auth("failure", "header_too_large", start_time) raise HTTPException( status_code=400, detail="x-rh-identity header exceeds maximum allowed size", @@ -353,6 +400,7 @@ async def __call__(self, request: Request) -> AuthTuple: decoded_str = decoded_bytes.decode("utf-8") except (ValueError, UnicodeDecodeError) as exc: logger.warning("Invalid base64 in x-rh-identity header: %s", exc) + _record_rh_identity_auth("failure", "invalid_base64", start_time) raise HTTPException( status_code=400, detail="Invalid base64 encoding in x-rh-identity header", @@ -363,18 +411,38 @@ async def __call__(self, request: Request) -> AuthTuple: identity_data = json.loads(decoded_str) except json.JSONDecodeError as exc: logger.warning("Invalid JSON in x-rh-identity header: %s", exc) + _record_rh_identity_auth("failure", "invalid_json", start_time) raise HTTPException( status_code=400, detail="Invalid JSON in x-rh-identity header" ) from exc + # Guard against non-dict JSON payloads (null, list, number, etc.) + if not isinstance(identity_data, dict): + logger.warning( + "x-rh-identity decoded to non-dict type: %s", + type(identity_data).__name__, + ) + _record_rh_identity_auth("failure", "invalid_identity", start_time) + raise HTTPException( + status_code=400, + detail="Invalid identity data in x-rh-identity header", + ) + # Extract and validate identity - rh_identity = RHIdentityData( - identity_data, - required_entitlements=self.required_entitlements, - ) + try: + rh_identity = RHIdentityData( + identity_data, + required_entitlements=self.required_entitlements, + ) - # Validate entitlements if configured - rh_identity.validate_entitlements() + # Validate entitlements if configured + rh_identity.validate_entitlements() + except HTTPException as exc: + reason = ( + "entitlement_missing" if exc.status_code == 403 else "invalid_identity" + ) + _record_rh_identity_auth("failure", reason, start_time) + raise # Store identity data in request.state for downstream access request.state.rh_identity_data = rh_identity @@ -390,5 +458,6 @@ async def __call__(self, request: Request) -> AuthTuple: rh_identity.get_org_id(), rh_identity.get_system_id(), ) + _record_rh_identity_auth("success", "authenticated", start_time) return user_id, username, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/authentication/utils.py b/src/authentication/utils.py index aad460c00..84b0899bd 100644 --- a/src/authentication/utils.py +++ b/src/authentication/utils.py @@ -1,10 +1,16 @@ """Authentication utility functions.""" +import time + from fastapi import HTTPException from starlette.datastructures import Headers +from log import get_logger +from metrics import recording from models.api.responses.error import UnauthorizedResponse +logger = get_logger(__name__) + def extract_user_token(headers: Headers) -> str: """Extract the bearer token from an HTTP Authorization header. @@ -33,3 +39,33 @@ def extract_user_token(headers: Headers) -> str: raise HTTPException(**response.model_dump()) return scheme_and_token[1] + + +def record_auth_metrics( + auth_module: str, result: str, reason: str, start_time: float +) -> None: + """Record authentication attempt and duration metrics together. + + Parameters: + ---------- + auth_module: Configured authentication module name. + result: Bounded result label, such as ``success`` or ``failure``. + reason: Bounded reason label for the result. + start_time: Monotonic clock time captured at the start of auth handling. + + Returns: + ------- + None: Metrics are recorded as a side effect. + """ + try: + recording.record_auth_attempt(auth_module, result, reason) + recording.record_auth_duration( + auth_module, result, time.monotonic() - start_time + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to record authentication metrics for module %s with result %s", + auth_module, + result, + exc_info=True, + ) diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 25a11a194..6833f58c6 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -22,6 +22,20 @@ float("inf"), ) +AUTH_DURATION_BUCKETS: Final[tuple[float, ...]] = ( + 0.001, + 0.005, + 0.01, + 0.025, + 0.05, + 0.1, + 0.25, + 0.5, + 1.0, + 2.5, + 5.0, + float("inf"), +) # Counter to track REST API calls # This will be used to count how many times each API endpoint is called # and the status code of the response @@ -72,10 +86,26 @@ ["provider", "model", "endpoint"], ) + # Histogram to measure the latency of direct LLM inference backend calls. -llm_inference_duration_seconds = Histogram( +llm_inference_duration_seconds: Final[Histogram] = Histogram( "ls_llm_inference_duration_seconds", "LLM inference call duration", ["provider", "model", "endpoint", "result"], buckets=LLM_INFERENCE_DURATION_BUCKETS, ) + +# Counter to track authentication attempts with bounded auth_module, result, and reason labels. +auth_attempts_total: Final[Counter] = Counter( + "ls_auth_attempts_total", + "Authentication attempts", + ["auth_module", "result", "reason"], +) + +# Histogram to measure authentication dependency latency with bounded module and result labels. +auth_duration_seconds: Final[Histogram] = Histogram( + "ls_auth_duration_seconds", + "Authentication duration", + ["auth_module", "result"], + buckets=AUTH_DURATION_BUCKETS, +) diff --git a/src/metrics/recording.py b/src/metrics/recording.py index a9b35d208..7dc0b1a12 100644 --- a/src/metrics/recording.py +++ b/src/metrics/recording.py @@ -10,10 +10,79 @@ from typing import Final import metrics +from constants import SUPPORTED_AUTHENTICATION_MODULES from log import get_logger logger = get_logger(__name__) +AUTH_RESULT_SUCCESS: Final[str] = "success" +AUTH_RESULT_FAILURE: Final[str] = "failure" +AUTH_RESULT_SKIPPED: Final[str] = "skipped" +AUTH_RESULT_UNKNOWN: Final[str] = "unknown" +AUTH_REASON_UNKNOWN: Final[str] = "unknown" + +ALLOWED_AUTH_RESULTS: Final[frozenset[str]] = frozenset( + { + AUTH_RESULT_SUCCESS, + AUTH_RESULT_FAILURE, + AUTH_RESULT_SKIPPED, + } +) +ALLOWED_AUTH_REASONS: Final[frozenset[str]] = frozenset( + { + "authenticated", + "authorization_check_error", + "empty_user_id", + "entitlement_missing", + "header_too_large", + "health_probe", + "invalid_base64", + "invalid_claim", + "invalid_identity", + "invalid_jwk", + "invalid_json", + "invalid_key", + "invalid_token", + "jwk_fetch_error", + "k8s_api_unavailable", + "k8s_config_error", + "malformed_token", + "metrics", + "missing_claim", + "missing_header", + "missing_token", + "no_auth_required", + "not_authorized", + "token_decode_error", + "token_expired", + "token_review_error", + "token_validation_error", + "unexpected_error", + "valid_key", + } +) + + +def normalize_auth_module(auth_module: str) -> str: + """Return a bounded authentication module label.""" + if auth_module in SUPPORTED_AUTHENTICATION_MODULES: + return auth_module + return AUTH_RESULT_UNKNOWN + + +def normalize_auth_result(result: str) -> str: + """Return a bounded authentication result label.""" + if result in ALLOWED_AUTH_RESULTS: + return result + return AUTH_RESULT_FAILURE + + +def normalize_auth_reason(reason: str) -> str: + """Return a bounded authentication reason label.""" + if reason in ALLOWED_AUTH_REASONS: + return reason + return AUTH_REASON_UNKNOWN + @contextmanager def measure_response_duration(path: str) -> Iterator[None]: @@ -157,3 +226,43 @@ def record_llm_inference_duration( ).observe(duration) except (AttributeError, TypeError, ValueError): logger.warning("Failed to update LLM inference duration metric", exc_info=True) + + +def record_auth_attempt(auth_module: str, result: str, reason: str) -> None: + """Record one authentication attempt. + + Args: + auth_module: Configured authentication module name. Unknown values are + recorded as ``unknown`` to keep metric cardinality bounded. + result: Bounded result label, such as ``success`` or ``failure``. + Unknown values are recorded as ``failure``. + reason: Bounded reason label for the result. Unknown values are recorded + as ``unknown``. + """ + try: + metrics.auth_attempts_total.labels( + normalize_auth_module(auth_module), + normalize_auth_result(result), + normalize_auth_reason(reason), + ).inc() + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authentication metric", exc_info=True) + + +def record_auth_duration(auth_module: str, result: str, duration: float) -> None: + """Record authentication duration. + + Args: + auth_module: Configured authentication module name. Unknown values are + recorded as ``unknown`` to keep metric cardinality bounded. + result: Bounded result label, such as ``success`` or ``failure``. + Unknown values are recorded as ``failure``. + duration: Authentication duration in seconds. + """ + try: + metrics.auth_duration_seconds.labels( + normalize_auth_module(auth_module), + normalize_auth_result(result), + ).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authentication duration metric", exc_info=True) diff --git a/src/observability/__init__.py b/src/observability/__init__.py index 8b3981ccc..65373b703 100644 --- a/src/observability/__init__.py +++ b/src/observability/__init__.py @@ -14,12 +14,13 @@ build_inference_event, build_responses_event, ) -from observability.splunk import send_splunk_event +from observability.splunk import dispatch_splunk_event, send_splunk_event __all__ = [ "InferenceEventData", "build_inference_event", "ResponsesEventData", "build_responses_event", + "dispatch_splunk_event", "send_splunk_event", ] diff --git a/src/observability/splunk.py b/src/observability/splunk.py index f00d7a0bb..2763b0ef1 100644 --- a/src/observability/splunk.py +++ b/src/observability/splunk.py @@ -1,10 +1,12 @@ """Async Splunk HEC client for sending telemetry events.""" +import asyncio import platform import time from typing import Any, Optional import aiohttp +from fastapi import BackgroundTasks from configuration import configuration from log import get_logger @@ -88,3 +90,58 @@ async def send_splunk_event(event: dict[str, Any], sourcetype: str) -> None: logger.warning("Splunk HEC request failed: %s", e) except TimeoutError: logger.warning("Splunk HEC request timed out after %ds", splunk_config.timeout) + + +# Strong references for fire-and-forget telemetry tasks so they aren't +# garbage-collected before completion (the event loop only holds weak refs). +_fire_and_forget_tasks: set[asyncio.Task[None]] = set() + + +def _cleanup_fire_and_forget_task(task: asyncio.Task[None]) -> None: + """Remove completed task from tracking and surface unexpected failures. + + Called as a done-callback on fire-and-forget asyncio tasks. Without + explicit retrieval of the task result, any exception raised inside + ``send_splunk_event`` would go unobserved and trigger a noisy + "Task exception was never retrieved" warning at garbage-collection time. + """ + _fire_and_forget_tasks.discard(task) + try: + task.result() + except asyncio.CancelledError: + logger.debug("Splunk fire-and-forget task was cancelled") + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Splunk fire-and-forget task failed", exc_info=True) + + +def dispatch_splunk_event( + event: dict[str, Any], + sourcetype: str, + background_tasks: Optional[BackgroundTasks] = None, + fire_and_forget: bool = False, +) -> None: + """Dispatch a Splunk event via BackgroundTasks or fire-and-forget. + + Centralizes the two dispatch strategies used across endpoints: + + - **BackgroundTasks** (default): FastAPI runs the send after the response + completes. Preferred for successful responses. + - **fire-and-forget**: Creates an ``asyncio.Task`` directly, bypassing + BackgroundTasks. Required on error paths where FastAPI discards + BackgroundTasks for non-2xx responses. + + No-op when ``background_tasks`` is None and ``fire_and_forget`` is False. + + Args: + event: The Splunk event payload dict. + sourcetype: Splunk sourcetype for the event. + background_tasks: FastAPI background task manager, or None. + fire_and_forget: When True, dispatch via ``asyncio.create_task()`` + instead of ``background_tasks``. + """ + if fire_and_forget: + task = asyncio.create_task(send_splunk_event(event, sourcetype)) + _fire_and_forget_tasks.add(task) + task.add_done_callback(_cleanup_fire_and_forget_task) + elif background_tasks is not None: + background_tasks.add_task(send_splunk_event, event, sourcetype) diff --git a/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-cpu.yaml b/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-cpu.yaml index 990dc2df3..e0e630fd2 100644 --- a/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-cpu.yaml +++ b/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-cpu.yaml @@ -24,7 +24,7 @@ spec: - --port - "8080" - --max-model-len - - "32768" + - "35936" image: quay.io/rh-ee-cpompeia/vllm-cpu:latest name: kserve-container env: diff --git a/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-gpu.yaml b/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-gpu.yaml index e925890d2..23a322114 100644 --- a/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-gpu.yaml +++ b/tests/e2e-prow/rhoai/manifests/vllm/vllm-runtime-gpu.yaml @@ -24,7 +24,7 @@ spec: - --port - "8080" - --max-model-len - - "32768" + - "35936" - --gpu-memory-utilization - "0.9" image: ${VLLM_IMAGE} diff --git a/tests/e2e-prow/rhoai/scripts/e2e-ops.sh b/tests/e2e-prow/rhoai/scripts/e2e-ops.sh index b98eafab3..b2191a158 100755 --- a/tests/e2e-prow/rhoai/scripts/e2e-ops.sh +++ b/tests/e2e-prow/rhoai/scripts/e2e-ops.sh @@ -192,10 +192,10 @@ verify_connectivity() { local http_code="" for ((attempt=1; attempt<=max_attempts; attempt++)); do - # First check /readiness to see if port-forward is alive (accept 200 or 401) + # First check /readiness to see if port-forward is alive (accept 200, 401, or 503) http_code=$(curl -s -o /dev/null -w '%{http_code}' --max-time 5 "http://localhost:$local_port/readiness" 2>/dev/null) || http_code="000" - if [[ "$http_code" == "200" || "$http_code" == "401" ]]; then + if [[ "$http_code" == "200" || "$http_code" == "401" || "$http_code" == "503" ]]; then # Port-forward works; now verify the app is fully initialized by hitting # a real endpoint. /v1/models requires the Llama Stack handshake to complete. # Accept 200 (no auth) or 401 (auth enabled) — both prove the full app diff --git a/tests/e2e/features/environment.py b/tests/e2e/features/environment.py index cce6c00a2..be8696077 100644 --- a/tests/e2e/features/environment.py +++ b/tests/e2e/features/environment.py @@ -237,6 +237,26 @@ def before_scenario(context: Context, scenario: Scenario) -> None: delattr(context, _attr) +def _dump_pod_logs_on_failure(scenario: Scenario, namespace: str) -> None: + """Dump llama-stack and lightspeed-stack pod logs when a scenario fails in Prow.""" + if scenario.status != "failed": + return + for pod in ("llama-stack-service", "lightspeed-stack-service"): + print(f"--- {pod} logs (scenario failed: {scenario.name}) ---") + try: + r = subprocess.run( + ["oc", "logs", pod, "-n", namespace, "--tail=100"], + capture_output=True, + text=True, + timeout=15, + check=False, + ) + print(r.stdout or r.stderr or "(no output)") + except subprocess.TimeoutExpired: + print("(timed out fetching logs)") + print(f"--- end {pod} logs ---") + + def after_scenario(context: Context, scenario: Scenario) -> None: """Run after each scenario is run. @@ -266,6 +286,11 @@ def after_scenario(context: Context, scenario: Scenario) -> None: used for the llama-stack health check. scenario (Scenario): Behave scenario (unused; shield restore uses context flags). """ + if is_prow_environment(): + _dump_pod_logs_on_failure( + scenario, os.environ.get("NAMESPACE", "e2e-rhoai-dsc") + ) + if getattr(context, "scenario_lightspeed_override_active", False): context.scenario_lightspeed_override_active = False feature_cfg = getattr(context, "feature_config", None) diff --git a/tests/unit/app/endpoints/test_responses_splunk.py b/tests/unit/app/endpoints/test_responses_splunk.py index 3c8485b1f..4495dbb33 100644 --- a/tests/unit/app/endpoints/test_responses_splunk.py +++ b/tests/unit/app/endpoints/test_responses_splunk.py @@ -15,19 +15,20 @@ from pytest_mock import MockerFixture from app.endpoints.responses import ( - _background_splunk_tasks, _get_user_agent, - _queue_responses_splunk_event, handle_non_streaming_response, handle_streaming_response, ) +from app.endpoints.responses_telemetry import queue_responses_splunk_event from configuration import AppConfig from models.api.requests import ResponsesRequest from models.common.turn_summary import RAGContext, TurnSummary from observability.formats.responses import ResponsesEventData +from observability.splunk import _fire_and_forget_tasks from tests.unit.app.endpoints.test_responses import build_api_params_and_context MODULE = "app.endpoints.responses" +TELEMETRY_MODULE = "app.endpoints.responses_telemetry" MOCK_AUTH = ( "00000001-0001-0001-0001-000000000001", "mock_username", @@ -99,16 +100,16 @@ def minimal_config_fixture() -> AppConfig: class TestQueueResponsesSplunkEvent: - """Unit tests for the _queue_responses_splunk_event helper.""" + """Unit tests for the queue_responses_splunk_event helper.""" def test_noop_when_background_tasks_is_none( self, mocker: MockerFixture, ) -> None: """Verify no-op when background_tasks is None (Splunk disabled).""" - mock_build = mocker.patch(f"{MODULE}.build_responses_event") + mock_build = mocker.patch(f"{TELEMETRY_MODULE}.build_responses_event") - _queue_responses_splunk_event( + queue_responses_splunk_event( background_tasks=None, input_text="user question", response_text="llm answer", @@ -128,11 +129,11 @@ def test_builds_event_and_queues_background_task( ) -> None: """Verify event is built from ResponsesEventData and queued via add_task.""" mock_build = mocker.patch( - f"{MODULE}.build_responses_event", return_value={"built": True} + f"{TELEMETRY_MODULE}.build_responses_event", return_value={"built": True} ) - mock_send = mocker.patch(f"{MODULE}.send_splunk_event") + mock_dispatch = mocker.patch(f"{TELEMETRY_MODULE}.dispatch_splunk_event") - _queue_responses_splunk_event( + queue_responses_splunk_event( background_tasks=mock_background_tasks, input_text="user question", response_text="llm answer", @@ -158,8 +159,11 @@ def test_builds_event_and_queues_background_task( assert event_data.input_tokens == 100 assert event_data.output_tokens == 50 - mock_background_tasks.add_task.assert_called_once_with( - mock_send, {"built": True}, "responses_completed" + mock_dispatch.assert_called_once_with( + {"built": True}, + "responses_completed", + background_tasks=mock_background_tasks, + fire_and_forget=False, ) def test_fire_and_forget_dispatches_via_create_task( @@ -167,17 +171,21 @@ def test_fire_and_forget_dispatches_via_create_task( mocker: MockerFixture, ) -> None: """fire_and_forget=True dispatches via asyncio.create_task with GC protection.""" - mocker.patch(f"{MODULE}.build_responses_event", return_value={"built": True}) + mocker.patch( + f"{TELEMETRY_MODULE}.build_responses_event", return_value={"built": True} + ) # Use MagicMock (not AsyncMock) so send_splunk_event() returns a # comparable return_value instead of a coroutine object. - mock_send = mocker.patch(f"{MODULE}.send_splunk_event", new=mocker.MagicMock()) + mock_send = mocker.patch( + "observability.splunk.send_splunk_event", new=mocker.MagicMock() + ) mock_task = mocker.MagicMock() mock_create_task = mocker.patch("asyncio.create_task", return_value=mock_task) # Clear any leftover tasks from other tests - _background_splunk_tasks.clear() + _fire_and_forget_tasks.clear() - _queue_responses_splunk_event( + queue_responses_splunk_event( background_tasks=None, input_text="user question", response_text="error message", @@ -193,14 +201,14 @@ def test_fire_and_forget_dispatches_via_create_task( mock_create_task.assert_called_once_with(mock_send.return_value) # Task is held in the module-level set to prevent GC - assert mock_task in _background_splunk_tasks + assert mock_task in _fire_and_forget_tasks # done_callback registered to clean up after completion mock_task.add_done_callback.assert_called_once() # Simulate task completion: callback removes from set done_callback = mock_task.add_done_callback.call_args[0][0] done_callback(mock_task) - assert mock_task not in _background_splunk_tasks + assert mock_task not in _fire_and_forget_tasks # --------------------------------------------------------------------------- @@ -209,7 +217,7 @@ def test_fire_and_forget_dispatches_via_create_task( class TestSplunkTelemetryHooks: - """Integration tests verifying _queue_responses_splunk_event is called at each hook.""" + """Integration tests verifying queue_responses_splunk_event is called at each hook.""" # -- Non-streaming paths ------------------------------------------------ @@ -256,7 +264,7 @@ async def test_non_streaming_shield_blocked( f"{MODULE}.OpenAIResponseObject.model_construct", return_value=mock_api_response, ) - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") api_params, context = build_api_params_and_context( updated_request=request, @@ -336,7 +344,7 @@ async def test_non_streaming_error_fires_telemetry( } ), ) - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") with pytest.raises(HTTPException): api_params, context = build_api_params_and_context( @@ -420,7 +428,7 @@ async def test_non_streaming_success( mocker.patch(f"{MODULE}.build_turn_summary", return_value=mock_turn_summary) mocker.patch(f"{MODULE}.deduplicate_referenced_documents", return_value=[]) - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") api_params, context = build_api_params_and_context( updated_request=request, @@ -483,7 +491,7 @@ async def test_streaming_shield_blocked( mocker.patch(f"{MODULE}.store_query_results") mock_client.conversations.items.create = mocker.AsyncMock() - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") api_params, context = build_api_params_and_context( updated_request=request, @@ -564,7 +572,7 @@ async def test_streaming_error_fires_telemetry( } ), ) - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") with pytest.raises(HTTPException): api_params, context = build_api_params_and_context( @@ -650,7 +658,7 @@ async def mock_stream() -> Any: mock_holder.get_client.return_value = mock_client mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder) - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") api_params, context = build_api_params_and_context( updated_request=request, @@ -691,7 +699,7 @@ async def test_splunk_disabled_no_background_tasks( minimal_config: AppConfig, mocker: MockerFixture, ) -> None: - """When background_tasks is None, _queue_responses_splunk_event is never called.""" + """When background_tasks is None, queue_responses_splunk_event is called but is a no-op.""" request = _request_with_model_and_conv("Bad input") mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) @@ -727,7 +735,7 @@ async def test_splunk_disabled_no_background_tasks( f"{MODULE}.OpenAIResponseObject.model_construct", return_value=mock_api_response, ) - mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event") + mock_queue = mocker.patch(f"{TELEMETRY_MODULE}.queue_responses_splunk_event") # background_tasks=None (the default) means Splunk is disabled api_params, context = build_api_params_and_context( diff --git a/tests/unit/authentication/test_jwk_token.py b/tests/unit/authentication/test_jwk_token.py index 7b3ae4d2f..e4ae16e2d 100644 --- a/tests/unit/authentication/test_jwk_token.py +++ b/tests/unit/authentication/test_jwk_token.py @@ -13,6 +13,7 @@ from pytest_mock import MockerFixture from authentication.jwk_token import JwkTokenAuthDependency, _jwk_cache +from constants import AUTH_MOD_JWK_TOKEN from models.config import JwkConfiguration, JwtConfiguration TEST_USER_ID = "test-user-123" @@ -503,6 +504,29 @@ async def test_no_bearer( assert detail["cause"] == "No token found in Authorization header" +@pytest.mark.asyncio +async def test_unexpected_token_extraction_error_records_metrics( + mocker: MockerFixture, + default_jwk_configuration: JwkConfiguration, + valid_token: str, +) -> None: + """Test unexpected token extraction errors are recorded before re-raising.""" + mocker.patch( + "authentication.jwk_token.extract_user_token", + side_effect=RuntimeError("header parser failed"), + ) + mock_metrics = mocker.patch("authentication.jwk_token.record_auth_metrics") + mocker.patch("authentication.jwk_token.time.monotonic", return_value=7.0) + dependency = JwkTokenAuthDependency(default_jwk_configuration) + + with pytest.raises(RuntimeError): + await dependency(dummy_request(valid_token)) + + mock_metrics.assert_called_once_with( + AUTH_MOD_JWK_TOKEN, "failure", "unexpected_error", 7.0 + ) + + @pytest.fixture def no_user_id_token( single_key_set: list[dict[str, Any]], @@ -532,6 +556,21 @@ def no_user_id_token( ).decode() +@pytest.fixture +def invalid_user_id_token( + single_key_set: list[dict[str, Any]], + token_payload: dict[str, Any], + token_header: dict[str, Any], +) -> str: + """Token with an invalid user_id claim value.""" + jwt_instance = JsonWebToken(algorithms=["RS256"]) + token_payload["user_id"] = "" + + return jwt_instance.encode( + token_header, token_payload, single_key_set[0]["private_key"] + ).decode() + + @pytest.mark.asyncio async def test_no_user_id( default_jwk_configuration: JwkConfiguration, @@ -552,6 +591,24 @@ async def test_no_user_id( ) +@pytest.mark.asyncio +async def test_invalid_user_id_claim_has_exception_chain( + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + invalid_user_id_token: str, +) -> None: + """Test invalid required claims preserve root-cause exception context.""" + _ = mocked_signing_keys_server + + dependency = JwkTokenAuthDependency(default_jwk_configuration) + + with pytest.raises(HTTPException) as exc_info: + await dependency(dummy_request(invalid_user_id_token)) + + assert exc_info.value.status_code == 401 + assert isinstance(exc_info.value.__cause__, ValueError) + + @pytest.fixture def no_username_token( single_key_set: list[dict[str, Any]], diff --git a/tests/unit/authentication/test_k8s.py b/tests/unit/authentication/test_k8s.py index 94e0eb633..2a8563796 100644 --- a/tests/unit/authentication/test_k8s.py +++ b/tests/unit/authentication/test_k8s.py @@ -8,7 +8,7 @@ import pytest from fastapi import HTTPException, Request -from kubernetes.client import AuthenticationV1Api, AuthorizationV1Api +from kubernetes.client import AuthenticationV1Api, AuthorizationV1Api, V1UserInfo from kubernetes.client.rest import ApiException from pytest_mock import MockerFixture @@ -21,6 +21,7 @@ K8SAuthDependency, K8sClientSingleton, K8sConfigurationError, + _create_subject_access_review, get_user_info, ) from configuration import AppConfig @@ -1135,3 +1136,90 @@ def test_get_user_info_api_error_handling( detail = cast(dict[str, str], exc_info.value.detail) assert detail["response"] == expected_response assert expected_cause_fragment in detail["cause"] + + +@pytest.mark.parametrize( + "api_status,reason,expected_status,expected_response,expected_cause_fragment", + [ + ( + HTTPStatus.SERVICE_UNAVAILABLE, + "Service Unavailable", + 503, + "Unable to connect to Kubernetes API", + "Service Unavailable", + ), + ( + HTTPStatus.TOO_MANY_REQUESTS, + "Too Many Requests", + 503, + "Unable to connect to Kubernetes API", + "Too Many Requests", + ), + ( + None, + "Connection failed", + 503, + "Unable to connect to Kubernetes API", + "Connection failed", + ), + ( + HTTPStatus.BAD_REQUEST, + "Bad Request", + 500, + "Internal server error", + "Bad Request", + ), + ], +) +def test_create_subject_access_review_api_error_handling( + mocker: MockerFixture, + api_status: Optional[int], + reason: str, + expected_status: int, + expected_response: str, + expected_cause_fragment: str, +) -> None: + """Test SubjectAccessReview maps Kubernetes API errors consistently.""" + mock_authz_api = mocker.patch("authentication.k8s.K8sClientSingleton.get_authz_api") + mock_authz_api.return_value.create_subject_access_review.side_effect = ApiException( + status=api_status, reason=reason + ) + mock_metrics = mocker.patch("authentication.k8s.record_auth_metrics") + user = cast( + V1UserInfo, MockK8sUser(username="user@example.com", groups=["lsc-group"]) + ) + + with pytest.raises(HTTPException) as exc_info: + _create_subject_access_review(user, "/api/lightspeed/v1/query", 10.0) + + assert exc_info.value.status_code == expected_status + detail = cast(dict[str, str], exc_info.value.detail) + assert detail["response"] == expected_response + assert expected_cause_fragment in detail["cause"] + mock_metrics.assert_called_once_with( + "k8s", "failure", "authorization_check_error", 10.0 + ) + + +def test_create_subject_access_review_authz_init_failure( + mocker: MockerFixture, +) -> None: + """Test get_authz_api() init failure records authorization check errors.""" + mocker.patch( + "authentication.k8s.K8sClientSingleton.get_authz_api", + side_effect=RuntimeError("failed to load kubeconfig"), + ) + mock_metrics = mocker.patch("authentication.k8s.record_auth_metrics") + user = cast( + V1UserInfo, MockK8sUser(username="user@example.com", groups=["lsc-group"]) + ) + + with pytest.raises(HTTPException) as exc_info: + _create_subject_access_review(user, "/api/lightspeed/v1/query", 10.0) + + assert exc_info.value.status_code == HTTPStatus.SERVICE_UNAVAILABLE + detail = cast(dict[str, str], exc_info.value.detail) + assert "Unable to initialize Kubernetes client" in detail["cause"] + mock_metrics.assert_called_once_with( + "k8s", "failure", "authorization_check_error", 10.0 + ) diff --git a/tests/unit/authentication/test_rh_identity.py b/tests/unit/authentication/test_rh_identity.py index 814678d75..aa08f85a2 100644 --- a/tests/unit/authentication/test_rh_identity.py +++ b/tests/unit/authentication/test_rh_identity.py @@ -254,6 +254,14 @@ def test_validate_entitlements( {"identity": {"type": "User", "org_id": "123"}}, "Invalid identity data", ), + ( + {"identity": {"type": "User", "org_id": "123", "user": 1}}, + "Invalid identity data", + ), + ( + {"identity": {"type": "User", "org_id": "123", "user": []}}, + "Invalid identity data", + ), ( { "identity": { @@ -278,6 +286,10 @@ def test_validate_entitlements( {"identity": {"type": "System", "org_id": "123"}}, "Invalid identity data", ), + ( + {"identity": {"type": "System", "org_id": "123", "system": 1}}, + "Invalid identity data", + ), ( {"identity": {"type": "System", "org_id": "123", "system": {}}}, "Invalid identity data", @@ -424,6 +436,49 @@ async def test_invalid_json(self, mocker: MockerFixture) -> None: assert exc_info.value.status_code == 400 assert "Invalid JSON" in str(exc_info.value.detail) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "identity_data", + [ + pytest.param({"identity": 1}, id="identity-not-dict"), + pytest.param( + {"identity": {"type": "User", "user": 1}}, + id="user-not-dict", + ), + pytest.param( + {"identity": {"type": "User", "user": []}}, + id="user-list", + ), + pytest.param( + {"identity": {"type": "System", "system": 1}}, + id="system-not-dict", + ), + ], + ) + async def test_invalid_nested_identity_payloads_record_metrics( + self, + mocker: MockerFixture, + identity_data: dict, + ) -> None: + """Test malformed nested identity payloads record invalid identity metrics.""" + auth_dep = RHIdentityAuthDependency() + header_value = create_auth_header(identity_data) + request = create_request_with_header(mocker, header_value) + mock_record = mocker.patch( + "authentication.rh_identity._record_rh_identity_auth" + ) + + with pytest.raises(HTTPException) as exc_info: + await auth_dep(request) + + assert exc_info.value.status_code == 400 + assert "Invalid identity data" in str(exc_info.value.detail) + mock_record.assert_called_once() + result, reason, duration = mock_record.call_args.args + assert result == "failure" + assert reason == "invalid_identity" + assert duration >= 0 + @pytest.mark.asyncio @pytest.mark.parametrize( "required_entitlements,should_raise,expected_error", @@ -509,9 +564,9 @@ def _mock_configuration( "path", [ "/readiness", + "/readiness/", "/liveness", - "/api/lightspeed/readiness", - "/api/lightspeed/liveness", + "/liveness/", ], ) async def test_probe_paths_skip_auth_when_enabled( @@ -532,8 +587,6 @@ async def test_probe_paths_skip_auth_when_enabled( [ "/readiness", "/liveness", - "/api/lightspeed/readiness", - "/api/lightspeed/liveness", ], ) async def test_probe_paths_require_auth_when_disabled( @@ -550,11 +603,19 @@ async def test_probe_paths_require_auth_when_disabled( assert exc_info.value.status_code == 401 @pytest.mark.asyncio - @pytest.mark.parametrize("path", ["/", "/v1/query"]) + @pytest.mark.parametrize( + "path", + [ + "/", + "/v1/query", + "/api/lightspeed/readiness", + "/api/lightspeed/liveness", + ], + ) async def test_non_probe_paths_require_auth_when_skip_enabled( self, mocker: MockerFixture, path: str ) -> None: - """Test non-probe paths still require auth even when skip_for_health_probes is True.""" + """Test non-probe paths still require auth even when probe skipping is enabled.""" self._mock_configuration(mocker, skip_for_health_probes=True) auth_dep = RHIdentityAuthDependency() @@ -587,7 +648,7 @@ def _mock_configuration(mocker: MockerFixture, skip_for_metrics: bool) -> None: "path", [ "/metrics", - "/api/lightspeed/metrics", + "/metrics/", ], ) async def test_metrics_path_skips_auth_when_enabled( @@ -607,7 +668,6 @@ async def test_metrics_path_skips_auth_when_enabled( "path", [ "/metrics", - "/api/lightspeed/metrics", ], ) async def test_metrics_path_requires_auth_when_disabled( @@ -624,11 +684,19 @@ async def test_metrics_path_requires_auth_when_disabled( assert exc_info.value.status_code == 401 @pytest.mark.asyncio - @pytest.mark.parametrize("path", ["/", "/v1/query"]) + @pytest.mark.parametrize( + "path", + [ + "/", + "/v1/query", + "/api/lightspeed/metrics", + "/v1/notmetrics", + ], + ) async def test_non_metrics_paths_require_auth_when_skip_enabled( self, mocker: MockerFixture, path: str ) -> None: - """Test non-metrics paths still require auth even when skip_for_metrics is True.""" + """Test non-metrics paths still require auth even when metrics skipping is enabled.""" self._mock_configuration(mocker, skip_for_metrics=True) auth_dep = RHIdentityAuthDependency() diff --git a/tests/unit/authentication/test_utils.py b/tests/unit/authentication/test_utils.py index 36490d2bb..a3c4c43d9 100644 --- a/tests/unit/authentication/test_utils.py +++ b/tests/unit/authentication/test_utils.py @@ -3,9 +3,10 @@ from typing import cast from fastapi import HTTPException +from pytest_mock import MockerFixture from starlette.datastructures import Headers -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics def test_extract_user_token() -> None: @@ -41,3 +42,50 @@ def test_extract_user_token_invalid_format() -> None: "Missing or invalid credentials provided by client" ) assert detail["cause"] == "No token found in Authorization header" + + +def test_record_auth_metrics_records_attempt_and_duration( + mocker: MockerFixture, +) -> None: + """Test recording auth attempt and duration through the shared helper.""" + mock_attempt = mocker.patch("authentication.utils.recording.record_auth_attempt") + mock_duration = mocker.patch("authentication.utils.recording.record_auth_duration") + mock_monotonic = mocker.patch("authentication.utils.time.monotonic") + mock_monotonic.return_value = 12.5 + + record_auth_metrics("jwk-token", "success", "authenticated", 10.0) + + mock_attempt.assert_called_once_with("jwk-token", "success", "authenticated") + mock_duration.assert_called_once_with("jwk-token", "success", 2.5) + + +def test_record_auth_metrics_records_failure_attempt_and_duration( + mocker: MockerFixture, +) -> None: + """Test recording failed auth attempts and duration through the shared helper.""" + mock_attempt = mocker.patch("authentication.utils.recording.record_auth_attempt") + mock_duration = mocker.patch("authentication.utils.recording.record_auth_duration") + mock_monotonic = mocker.patch("authentication.utils.time.monotonic") + mock_monotonic.return_value = 15.25 + + record_auth_metrics("jwk-token", "failure", "missing_token", 10.0) + + mock_attempt.assert_called_once_with("jwk-token", "failure", "missing_token") + mock_duration.assert_called_once_with("jwk-token", "failure", 5.25) + + +def test_record_auth_metrics_swallows_recording_exceptions( + mocker: MockerFixture, +) -> None: + """Test that record_auth_metrics catches and logs recording failures.""" + mocker.patch( + "authentication.utils.recording.record_auth_attempt", + side_effect=RuntimeError("metrics backend unavailable"), + ) + mock_warning = mocker.patch("authentication.utils.logger.warning") + + # Should not raise despite recording failure + record_auth_metrics("jwk-token", "failure", "missing_token", 10.0) + + mock_warning.assert_called_once() + assert "Failed to record authentication metrics" in mock_warning.call_args[0][0] diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py index 6cf4fcb6e..89d8e8eb6 100644 --- a/tests/unit/metrics/test_recording.py +++ b/tests/unit/metrics/test_recording.py @@ -234,3 +234,81 @@ def test_record_llm_inference_duration_bounds_result_label( "vertexai", "gemini", "/v1/responses", "failure" ) mock_metric.labels.return_value.observe.assert_called_once_with(2.0) + + +def test_record_auth_attempt_updates_metric( + mocker: MockerFixture, +) -> None: + """Test auth attempt helper success behavior.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_attempts_total") + + recording.record_auth_attempt("rh-identity", "success", "authenticated") + + mock_metric.labels.assert_called_once_with( + "rh-identity", "success", "authenticated" + ) + mock_metric.labels.return_value.inc.assert_called_once() + + +def test_record_auth_attempt_bounds_labels(mocker: MockerFixture) -> None: + """Test auth attempt helper normalizes unbounded label values.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_attempts_total") + + recording.record_auth_attempt("customer-123", "timeout", "database-down") + + mock_metric.labels.assert_called_once_with("unknown", "failure", "unknown") + mock_metric.labels.return_value.inc.assert_called_once() + + +def test_record_auth_attempt_logs_metric_errors( + mocker: MockerFixture, + recording_logger: MockType, +) -> None: + """Test auth attempt helper logs and swallows metric failures.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_attempts_total") + + mock_metric.labels.return_value.inc.side_effect = AttributeError("missing") + + recording.record_auth_attempt("rh-identity", "success", "authenticated") + + recording_logger.warning.assert_called_once_with( + "Failed to update authentication metric", exc_info=True + ) + + +def test_record_auth_duration_updates_metric( + mocker: MockerFixture, +) -> None: + """Test auth duration helper success behavior.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_duration_seconds") + + recording.record_auth_duration("rh-identity", "success", 0.5) + + mock_metric.labels.assert_called_once_with("rh-identity", "success") + mock_metric.labels.return_value.observe.assert_called_once_with(0.5) + + +def test_record_auth_duration_bounds_labels(mocker: MockerFixture) -> None: + """Test auth duration helper normalizes unbounded label values.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_duration_seconds") + + recording.record_auth_duration("customer-123", "timeout", 0.5) + + mock_metric.labels.assert_called_once_with("unknown", "failure") + mock_metric.labels.return_value.observe.assert_called_once_with(0.5) + + +def test_record_auth_duration_logs_metric_errors( + mocker: MockerFixture, + recording_logger: MockType, +) -> None: + """Test auth duration helper logs and swallows metric failures.""" + mock_metric = mocker.patch("metrics.recording.metrics.auth_duration_seconds") + + mock_metric.labels.return_value.observe.side_effect = TypeError("bad") + + recording.record_auth_duration("rh-identity", "success", 0.5) + + recording_logger.warning.assert_called_once_with( + "Failed to update authentication duration metric", exc_info=True + ) diff --git a/tests/unit/observability/test_splunk.py b/tests/unit/observability/test_splunk.py index b4b842ea6..4360171e2 100644 --- a/tests/unit/observability/test_splunk.py +++ b/tests/unit/observability/test_splunk.py @@ -1,5 +1,7 @@ """Unit tests for Splunk HEC client.""" +import asyncio +from collections.abc import Generator from pathlib import Path from typing import Any, Optional @@ -7,7 +9,13 @@ import pytest from pytest_mock import MockerFixture -from observability.splunk import _read_token_from_file, send_splunk_event +from observability.splunk import ( + _cleanup_fire_and_forget_task, + _fire_and_forget_tasks, + _read_token_from_file, + dispatch_splunk_event, + send_splunk_event, +) @pytest.fixture(name="mock_splunk_config") @@ -185,3 +193,126 @@ async def test_logs_warning_on_error( await send_splunk_event({"test": "event"}, "test_sourcetype") mock_logger.warning.assert_called() + + +# --------------------------------------------------------------------------- +# dispatch_splunk_event tests +# --------------------------------------------------------------------------- + + +class TestDispatchSplunkEvent: + """Tests for the dispatch_splunk_event dispatcher function.""" + + @pytest.fixture(autouse=True) + def _cleanup_fire_and_forget(self) -> Generator[None, None, None]: + """Ensure _fire_and_forget_tasks is cleaned after each test.""" + yield + _fire_and_forget_tasks.clear() + + def test_noop_when_no_dispatch_method(self, mocker: MockerFixture) -> None: + """No-op when background_tasks is None and fire_and_forget is False.""" + mock_send = mocker.patch("observability.splunk.send_splunk_event") + mock_create_task = mocker.patch("observability.splunk.asyncio.create_task") + + dispatch_splunk_event({"k": "v"}, "test_sourcetype") + + mock_send.assert_not_called() + mock_create_task.assert_not_called() + + def test_dispatches_via_background_tasks(self, mocker: MockerFixture) -> None: + """Queues send_splunk_event via BackgroundTasks when provided.""" + mock_bg = mocker.MagicMock() + + dispatch_splunk_event({"k": "v"}, "test_sourcetype", background_tasks=mock_bg) + + mock_bg.add_task.assert_called_once_with( + send_splunk_event, {"k": "v"}, "test_sourcetype" + ) + + def test_dispatches_fire_and_forget(self, mocker: MockerFixture) -> None: + """Creates asyncio task and registers it for GC protection.""" + sentinel_task = mocker.MagicMock() + mock_create_task = mocker.patch( + "observability.splunk.asyncio.create_task", return_value=sentinel_task + ) + # Prevent real coroutine creation; the mock returns a coroutine-like + # object that create_task can accept. + mocker.patch("observability.splunk.send_splunk_event") + + dispatch_splunk_event({"k": "v"}, "test_sourcetype", fire_and_forget=True) + + mock_create_task.assert_called_once() + assert sentinel_task in _fire_and_forget_tasks + sentinel_task.add_done_callback.assert_called_once_with( + _cleanup_fire_and_forget_task + ) + + def test_fire_and_forget_takes_priority(self, mocker: MockerFixture) -> None: + """When both background_tasks and fire_and_forget are set, fire-and-forget wins.""" + mock_bg = mocker.MagicMock() + sentinel_task = mocker.MagicMock() + mocker.patch( + "observability.splunk.asyncio.create_task", return_value=sentinel_task + ) + mocker.patch("observability.splunk.send_splunk_event") + + dispatch_splunk_event( + {"k": "v"}, + "test_sourcetype", + background_tasks=mock_bg, + fire_and_forget=True, + ) + + mock_bg.add_task.assert_not_called() + assert sentinel_task in _fire_and_forget_tasks + + +# --------------------------------------------------------------------------- +# _cleanup_fire_and_forget_task tests +# --------------------------------------------------------------------------- + + +class TestCleanupFireAndForgetTask: + """Tests for the fire-and-forget done-callback.""" + + @pytest.fixture(autouse=True) + def _cleanup_fire_and_forget(self) -> Generator[None, None, None]: + """Ensure _fire_and_forget_tasks is cleaned after each test.""" + yield + _fire_and_forget_tasks.clear() + + def test_discards_task_on_success(self, mocker: MockerFixture) -> None: + """Successful task is removed from tracking set.""" + task = mocker.MagicMock() + task.result.return_value = None + _fire_and_forget_tasks.add(task) + + _cleanup_fire_and_forget_task(task) + + assert task not in _fire_and_forget_tasks + + def test_logs_debug_on_cancellation(self, mocker: MockerFixture) -> None: + """Cancelled task logs at debug level and is removed.""" + task = mocker.MagicMock() + task.result.side_effect = asyncio.CancelledError() + _fire_and_forget_tasks.add(task) + mock_logger = mocker.patch("observability.splunk.logger") + + _cleanup_fire_and_forget_task(task) + + assert task not in _fire_and_forget_tasks + mock_logger.debug.assert_called_once() + + def test_logs_warning_on_exception(self, mocker: MockerFixture) -> None: + """Failed task logs warning with exc_info and is removed.""" + task = mocker.MagicMock() + task.result.side_effect = RuntimeError("connection refused") + _fire_and_forget_tasks.add(task) + mock_logger = mocker.patch("observability.splunk.logger") + + _cleanup_fire_and_forget_task(task) + + assert task not in _fire_and_forget_tasks + mock_logger.warning.assert_called_once_with( + "Splunk fire-and-forget task failed", exc_info=True + )