diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 1e49a050f..a12721260 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -4,7 +4,9 @@ import asyncio import json +import time from collections.abc import AsyncIterator +from dataclasses import dataclass from datetime import UTC, datetime from typing import Annotated, Any, Final, Optional, cast @@ -39,6 +41,7 @@ from configuration import configuration from constants import SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER from log import get_logger +from metrics import recording from models.config import Action from models.requests import ResponsesRequest from models.responses import ( @@ -115,6 +118,80 @@ _USER_AGENT_MAX_LENGTH: Final[int] = 128 +@dataclass +class StreamChunkProcessingResult: + """Result of preparing one Responses API stream chunk for SSE output.""" + + event: Optional[str] + sequence_number: int + latest_response_object: Optional[OpenAIResponseObject] = None + inference_metric_recorded: bool = False + + +def _record_response_inference_result( + model_id: str, + endpoint_path: str, + result: str, + duration: float, + record_failure: bool = False, +) -> None: + """Record backend inference metrics for a Responses API call. + + Args: + model_id: Model identifier in provider/model format. + endpoint_path: API endpoint path for metric labeling. + result: Bounded result label for the inference duration metric. + duration: Backend inference call duration in seconds. + record_failure: Whether to increment the LLM failure counter as well. + """ + provider, model = extract_provider_and_model_from_model_id(model_id) + if record_failure: + recording.record_llm_failure(provider, model, endpoint_path) + recording.record_llm_inference_duration( + provider, model, endpoint_path, result, duration + ) + + +def _check_response_quota(user_id: str, endpoint_path: str) -> None: + """Check response quota availability and record bounded quota metrics.""" + quota_start_time = time.monotonic() + try: + check_tokens_available(configuration.quota_limiters, user_id) + except HTTPException: + recording.record_quota_check( + endpoint_path, + "user_id", + "failure", + time.monotonic() - quota_start_time, + ) + raise + recording.record_quota_check( + endpoint_path, + "user_id", + "success", + time.monotonic() - quota_start_time, + ) + + +def _record_stream_terminal_inference_result( + event_type: Optional[str], + api_params: ResponsesApiParams, + endpoint_path: str, + inference_start_time: Optional[float], +) -> None: + """Record inference metrics after a terminal streaming response event.""" + if inference_start_time is None: + return + inference_result = "success" if event_type == "response.completed" else "failure" + _record_response_inference_result( + api_params.model, + endpoint_path, + inference_result, + time.monotonic() - inference_start_time, + record_failure=(inference_result == "failure"), + ) + + def _get_user_agent(request: Request) -> Optional[str]: """Extract and sanitize the User-Agent header from the request. @@ -278,8 +355,9 @@ async def responses_endpoint_handler( await check_mcp_auth(configuration, mcp_headers, token, request.headers) - # Check token availability - check_tokens_available(configuration.quota_limiters, user_id) + endpoint_path = "/v1/responses" + + _check_response_quota(user_id, endpoint_path) # Enforce RBAC: optionally disallow overriding model in requests validate_model_provider_override( @@ -331,7 +409,6 @@ async def responses_endpoint_handler( ) attachments_text = extract_attachments_text(original_request.input) - endpoint_path = "/v1/responses" moderation_result = await run_shield_moderation( client, input_text + "\n\n" + attachments_text, @@ -461,6 +538,7 @@ async def handle_streaming_response( user_agent=user_agent, ) else: + inference_start_time = time.monotonic() try: response = await client.responses.create( **api_params.model_dump(exclude_none=True) @@ -475,8 +553,16 @@ async def handle_streaming_response( inline_rag_context=inline_rag_context, filter_server_tools=filter_server_tools, endpoint_path=endpoint_path, + inference_start_time=inference_start_time, ) except RuntimeError as e: # library mode wraps 413 into runtime error + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, + ) if is_context_length_error(str(e)): _queue_responses_splunk_event( background_tasks=background_tasks, @@ -494,6 +580,13 @@ async def handle_streaming_response( raise HTTPException(**error_response.model_dump()) from e raise e except APIConnectionError as e: + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, + ) _queue_responses_splunk_event( background_tasks=background_tasks, input_text=input_text, @@ -512,6 +605,13 @@ async def handle_streaming_response( ) raise HTTPException(**error_response.model_dump()) from e except (LLSApiStatusError, OpenAIAPIStatusError) as e: + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, + ) _queue_responses_splunk_event( background_tasks=background_tasks, input_text=input_text, @@ -794,6 +894,102 @@ def _populate_turn_summary( turn_summary.rag_chunks = inline_rag_context.rag_chunks + tool_rag_chunks +def _process_stream_chunk( # pylint: disable=too-many-arguments,too-many-positional-arguments + chunk: OpenAIResponseObjectStream, + original_request: ResponsesRequest, + api_params: ResponsesApiParams, + user_id: str, + turn_summary: TurnSummary, + configured_mcp_labels: set[str], + server_mcp_output_indices: set[int], + sequence_number: int, + endpoint_path: str, + inference_start_time: Optional[float], + normalized_conv_id: str, +) -> StreamChunkProcessingResult: + """Prepare one streaming chunk for SSE output and terminal accounting.""" + event_type = getattr(chunk, "type", None) + logger.debug("Processing streaming chunk, type: %s", event_type) + + if _should_filter_mcp_chunk( + chunk, event_type, configured_mcp_labels, server_mcp_output_indices + ): + return StreamChunkProcessingResult(event=None, sequence_number=sequence_number) + + chunk_dict = chunk.model_dump(exclude_none=True, by_alias=True) + chunk_dict["sequence_number"] = sequence_number + sequence_number += 1 + + if "response" in chunk_dict: + chunk_dict["response"]["conversation"] = normalized_conv_id + _sanitize_response_dict( + chunk_dict["response"], configured_mcp_labels, original_request + ) + tools = chunk_dict["response"].get("tools") + if tools is not None: + chunk_dict["response"]["tools"] = translate_vector_store_ids_to_user_facing( + tools, + configuration.rag_id_mapping, + ) + + if event_type == "response.in_progress": + chunk_dict["response"]["available_quotas"] = {} + chunk_dict["response"]["output_text"] = "" + + latest_response_object = None + inference_metric_recorded = False + if event_type in ("response.completed", "response.incomplete", "response.failed"): + latest_response_object = _process_terminal_stream_chunk( + chunk, + chunk_dict, + api_params, + user_id, + turn_summary, + endpoint_path, + inference_start_time, + ) + inference_metric_recorded = True + + return StreamChunkProcessingResult( + event=f"event: {event_type or 'error'}\ndata: {json.dumps(chunk_dict)}\n\n", + sequence_number=sequence_number, + latest_response_object=latest_response_object, + inference_metric_recorded=inference_metric_recorded, + ) + + +def _process_terminal_stream_chunk( # pylint: disable=too-many-arguments,too-many-positional-arguments + chunk: OpenAIResponseObjectStream, + chunk_dict: dict[str, Any], + api_params: ResponsesApiParams, + user_id: str, + turn_summary: TurnSummary, + endpoint_path: str, + inference_start_time: Optional[float], +) -> OpenAIResponseObject: + """Handle token, quota, output, and inference metrics for terminal chunks.""" + latest_response_object = cast(OpenAIResponseObject, cast(Any, chunk).response) + turn_summary.token_usage = extract_token_usage( + latest_response_object.usage, api_params.model, endpoint_path + ) + consume_query_tokens( + user_id=user_id, + model_id=api_params.model, + token_usage=turn_summary.token_usage, + ) + chunk_dict["response"]["available_quotas"] = get_available_quotas( + quota_limiters=configuration.quota_limiters, user_id=user_id + ) + turn_summary.llm_response = extract_text_from_response_items( + latest_response_object.output + ) + chunk_dict["response"]["output_text"] = turn_summary.llm_response + _record_stream_terminal_inference_result( + getattr(chunk, "type", None), api_params, endpoint_path, inference_start_time + ) + return latest_response_object + + async def response_generator( stream: AsyncIterator[OpenAIResponseObjectStream], original_request: ResponsesRequest, @@ -804,6 +1000,7 @@ async def response_generator( inline_rag_context: RAGContext, filter_server_tools: bool = False, endpoint_path: str = "", + inference_start_time: Optional[float] = None, ) -> AsyncIterator[str]: """Generate SSE-formatted streaming response with LCORE-enriched events. @@ -817,6 +1014,7 @@ async def response_generator( inline_rag_context: Inline RAG context to be used for the response filter_server_tools: Whether to filter server-deployed MCP tool events from the stream endpoint_path: API endpoint path used for metric labeling. + inference_start_time: Monotonic clock time when the backend stream was requested. Yields: SSE-formatted strings for streaming events, ending with [DONE] """ @@ -829,77 +1027,42 @@ async def response_generator( configured_mcp_labels = {s.name for s in configuration.mcp_servers} # Track output indices of server-deployed MCP calls to filter their events server_mcp_output_indices: set[int] = set() + inference_metric_recorded = False - async for chunk in stream: - event_type = getattr(chunk, "type", None) - logger.debug("Processing streaming chunk, type: %s", event_type) - - # Filter out streaming events for server-deployed MCP tools. - # These are handled internally by LCS and should not be forwarded - # to clients that don't understand the mcp_call item type. - if _should_filter_mcp_chunk( - chunk, event_type, configured_mcp_labels, server_mcp_output_indices - ): - continue - - chunk_dict = chunk.model_dump(exclude_none=True, by_alias=True) - - # Create own sequence number for chunks to maintain order - chunk_dict["sequence_number"] = sequence_number - sequence_number += 1 - - if "response" in chunk_dict: - chunk_dict["response"]["conversation"] = normalized_conv_id - _sanitize_response_dict( - chunk_dict["response"], - configured_mcp_labels, + try: + async for chunk in stream: + result = _process_stream_chunk( + chunk, original_request, + api_params, + user_id, + turn_summary, + configured_mcp_labels, + server_mcp_output_indices, + sequence_number, + endpoint_path, + inference_start_time, + normalized_conv_id, ) - tools = chunk_dict["response"].get("tools") - if tools is not None: - chunk_dict["response"]["tools"] = ( - translate_vector_store_ids_to_user_facing( - tools, - configuration.rag_id_mapping, - ) - ) - # Intermediate response - no quota consumption and text yet - if event_type == "response.in_progress": - chunk_dict["response"]["available_quotas"] = {} - chunk_dict["response"]["output_text"] = "" - - # Handle completion, incomplete, and failed events - only quota handling here - if event_type in ( - "response.completed", - "response.incomplete", - "response.failed", - ): - latest_response_object = cast( - OpenAIResponseObject, cast(Any, chunk).response - ) - - # Extract and consume tokens if any were used - turn_summary.token_usage = extract_token_usage( - latest_response_object.usage, api_params.model, endpoint_path - ) - consume_query_tokens( - user_id=user_id, - model_id=api_params.model, - token_usage=turn_summary.token_usage, - ) - - # Get available quotas after token consumption - available_quotas = get_available_quotas( - quota_limiters=configuration.quota_limiters, user_id=user_id + sequence_number = result.sequence_number + if result.event is None: + continue + if result.latest_response_object is not None: + latest_response_object = result.latest_response_object + inference_metric_recorded = ( + inference_metric_recorded or result.inference_metric_recorded ) - chunk_dict["response"]["available_quotas"] = available_quotas - turn_summary.llm_response = extract_text_from_response_items( - latest_response_object.output + yield result.event + except (RuntimeError, APIConnectionError, LLSApiStatusError, OpenAIAPIStatusError): + if not inference_metric_recorded and inference_start_time is not None: + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, ) - chunk_dict["response"]["output_text"] = turn_summary.llm_response - - data_json = json.dumps(chunk_dict) - yield f"event: {event_type or 'error'}\ndata: {data_json}\n\n" + raise # Extract response metadata from final response object if latest_response_object: @@ -1070,6 +1233,7 @@ async def handle_non_streaming_response( user_agent=user_agent, ) else: + inference_start_time = time.monotonic() try: api_response = cast( OpenAIResponseObject, @@ -1077,6 +1241,12 @@ async def handle_non_streaming_response( **api_params.model_dump(exclude_none=True) ), ) + _record_response_inference_result( + api_params.model, + endpoint_path, + "success", + time.monotonic() - inference_start_time, + ) token_usage = extract_token_usage( api_response.usage, api_params.model, endpoint_path ) @@ -1097,6 +1267,13 @@ async def handle_non_streaming_response( ) except RuntimeError as e: + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, + ) if is_context_length_error(str(e)): _queue_responses_splunk_event( background_tasks=background_tasks, @@ -1114,6 +1291,13 @@ async def handle_non_streaming_response( raise HTTPException(**error_response.model_dump()) from e raise e except APIConnectionError as e: + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, + ) _queue_responses_splunk_event( background_tasks=background_tasks, input_text=input_text, @@ -1132,6 +1316,13 @@ async def handle_non_streaming_response( ) raise HTTPException(**error_response.model_dump()) from e except (LLSApiStatusError, OpenAIAPIStatusError) as e: + _record_response_inference_result( + api_params.model, + endpoint_path, + "failure", + time.monotonic() - inference_start_time, + record_failure=True, + ) _queue_responses_splunk_event( background_tasks=background_tasks, input_text=input_text, diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index b0d19dfdb..ed1a4a40b 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -455,6 +455,9 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po """ inference_time = time.monotonic() - start_time recording.record_llm_failure(provider, model, endpoint_path) + recording.record_llm_inference_duration( + provider, model, endpoint_path, "failure", inference_time + ) _queue_splunk_event( background_tasks, infer_request, @@ -532,6 +535,36 @@ def _resolve_quota_subject(request: Request, auth: AuthTuple) -> Optional[str]: return system_id +def _check_infer_quota( + request: Request, auth: AuthTuple, endpoint_path: str +) -> Optional[str]: + """Check infer quota availability and record bounded quota metrics.""" + quota_id = _resolve_quota_subject(request, auth) + quota_type = configuration.rlsapi_v1.quota_subject or "disabled" + if quota_id is None: + recording.record_quota_check(endpoint_path, quota_type, "skipped", 0.0) + return None + + quota_start_time = time.monotonic() + try: + check_tokens_available(configuration.quota_limiters, quota_id) + except HTTPException: + recording.record_quota_check( + endpoint_path, + quota_type, + "failure", + time.monotonic() - quota_start_time, + ) + raise + recording.record_quota_check( + endpoint_path, + quota_type, + "success", + time.monotonic() - quota_start_time, + ) + return quota_id + + def _build_infer_response( response_text: str, request_id: str, @@ -669,12 +702,11 @@ async def infer_endpoint( # pylint: disable=R0914 """ # Authentication enforced by get_auth_dependency(), authorization by @authorize decorator. check_configuration_loaded(configuration) + endpoint_path = "/v1/infer" # Quota enforcement: resolve subject and check availability before any work. # No-op when quota_subject is not configured or no quota limiters exist. - quota_id = _resolve_quota_subject(request, auth) - if quota_id is not None: - check_tokens_available(configuration.quota_limiters, quota_id) + quota_id = _check_infer_quota(request, auth, endpoint_path) request_id = get_suid() @@ -685,8 +717,6 @@ async def infer_endpoint( # pylint: disable=R0914 "Request %s: Combined input source length: %d", request_id, len(input_source) ) - endpoint_path = "/v1/infer" - # Run shield moderation on user input before inference. # Uses all configured shields; no-op when no shields are registered. # Runs before model/tool discovery so blocked requests short-circuit @@ -721,6 +751,9 @@ async def infer_endpoint( # pylint: disable=R0914 response_text = extract_text_from_response_items(response.output) token_usage = extract_token_usage(response.usage, model_id, endpoint_path) inference_time = time.monotonic() - start_time + recording.record_llm_inference_duration( + provider, model, endpoint_path, "success", inference_time + ) except _INFER_HANDLED_EXCEPTIONS as error: if response is not None: extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type] diff --git a/src/authentication/api_key_token.py b/src/authentication/api_key_token.py index ec24335c2..95c2b9ae3 100644 --- a/src/authentication/api_key_token.py +++ b/src/authentication/api_key_token.py @@ -7,12 +7,14 @@ """ import secrets +import time from fastapi import HTTPException, Request, status from authentication.interface import AuthInterface -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_APIKEY_TOKEN, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -59,16 +61,28 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: HTTPException: If the bearer token is missing or doesn't match the configured API key (HTTP 401). """ + start_time = time.monotonic() + # try to extract user token from request - user_token = extract_user_token(request.headers) + try: + user_token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics( + AUTH_MOD_APIKEY_TOKEN, "failure", "missing_token", start_time + ) + raise # API Key validation. Use secrets.compare_digest for constant-time comparison if not secrets.compare_digest( user_token, self.config.api_key.get_secret_value() ): + record_auth_metrics( + AUTH_MOD_APIKEY_TOKEN, "failure", "invalid_key", start_time + ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key", ) + record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "success", "valid_key", start_time) return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authentication/jwk_token.py b/src/authentication/jwk_token.py index b8c82be10..65d0ac732 100644 --- a/src/authentication/jwk_token.py +++ b/src/authentication/jwk_token.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with JWK based JWT auth.""" import json +import time from asyncio import Lock from collections.abc import Callable from typing import Any @@ -17,8 +18,9 @@ from fastapi import HTTPException, Request from authentication.interface import AuthInterface, AuthTuple -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_JWK_TOKEN, DEFAULT_VIRTUAL_PATH, ) from log import get_logger @@ -139,6 +141,85 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key: return _internal +async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet: + """Load the configured JWK set and record bounded auth failures.""" + try: + return await get_jwk_set(str(config.url)) + except aiohttp.ClientError as exc: + logger.error("Failed to fetch JWK set: %s", exc) + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "jwk_fetch_error", start_time + ) + response = UnauthorizedResponse( + cause="Unable to reach authentication key server" + ) + raise HTTPException(**response.model_dump()) from exc + except json.JSONDecodeError as exc: + logger.error("Invalid JSON in JWK set response: %s", exc) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_json", start_time) + response = UnauthorizedResponse( + cause="Authentication key server returned invalid data" + ) + raise HTTPException(**response.model_dump()) from exc + except JoseError as exc: + logger.error("Invalid JWK set format: %s", exc) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_jwk", start_time) + response = UnauthorizedResponse(cause="Authentication keys are malformed") + raise HTTPException(**response.model_dump()) from exc + + +def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any: + """Decode a JWT and record bounded auth failures.""" + try: + return jwt.decode(user_token, key=key_resolver_func(jwk_set)) + except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: + logger.warning("Token decode error: %s", exc) + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time + ) + cause_map = { + KeyNotFoundError: "Token signed by unknown key", + BadSignatureError: "Invalid token signature", + DecodeError: "Token could not be decoded", + JoseError: "Token format error", + } + response = UnauthorizedResponse( + cause=cause_map.get(type(exc), "Unknown token error") + ) + raise HTTPException(**response.model_dump()) from exc + + +def _validate_jwk_claims(claims: Any, start_time: float) -> None: + """Validate decoded JWT claims and record bounded auth failures.""" + try: + claims.validate() + except ExpiredTokenError as exc: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "token_expired", start_time) + response = UnauthorizedResponse(cause="Token has expired") + raise HTTPException(**response.model_dump()) from exc + except JoseError as exc: + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "token_validation_error", start_time + ) + response = UnauthorizedResponse(cause="Token validation failed") + raise HTTPException(**response.model_dump()) from exc + + +def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str: + """Return a required JWT claim and record bounded auth failures when missing.""" + try: + value = claims[claim_name] + except KeyError as exc: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time) + response = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}") + raise HTTPException(**response.model_dump()) from exc + if not isinstance(value, str) or not value: + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time) + response = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}") + raise HTTPException(**response.model_dump()) + return value + + class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods """JWK AuthDependency class for JWK-based JWT authentication.""" @@ -187,73 +268,34 @@ async def __call__(self, request: Request) -> AuthTuple: extracted from the validated JWT. Only returned on successful authentication; all error paths raise HTTPException. """ + start_time = time.monotonic() + if not request.headers.get("Authorization"): + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time + ) response = UnauthorizedResponse(cause="No Authorization header found") raise HTTPException(**response.model_dump()) - user_token = extract_user_token(request.headers) - - try: - jwk_set = await get_jwk_set(str(self.config.url)) - except aiohttp.ClientError as exc: - logger.error("Failed to fetch JWK set: %s", exc) - response = UnauthorizedResponse( - cause="Unable to reach authentication key server" - ) - raise HTTPException(**response.model_dump()) from exc - except json.JSONDecodeError as exc: - logger.error("Invalid JSON in JWK set response: %s", exc) - response = UnauthorizedResponse( - cause="Authentication key server returned invalid data" - ) - raise HTTPException(**response.model_dump()) from exc - except JoseError as exc: - logger.error("Invalid JWK set format: %s", exc) - response = UnauthorizedResponse(cause="Authentication keys are malformed") - raise HTTPException(**response.model_dump()) from exc - - try: - claims = jwt.decode(user_token, key=key_resolver_func(jwk_set)) - except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc: - logger.warning("Token decode error: %s", exc) - cause_map = { - KeyNotFoundError: "Token signed by unknown key", - BadSignatureError: "Invalid token signature", - DecodeError: "Token could not be decoded", - JoseError: "Token format error", - } - response = UnauthorizedResponse( - cause=cause_map.get(type(exc), "Unknown token error") - ) - raise HTTPException(**response.model_dump()) from exc - - try: - claims.validate() - except ExpiredTokenError as exc: - response = UnauthorizedResponse(cause="Token has expired") - raise HTTPException(**response.model_dump()) from exc - except JoseError as exc: - response = UnauthorizedResponse(cause="Token validation failed") - raise HTTPException(**response.model_dump()) from exc - - try: - user_id: str = claims[self.config.jwt_configuration.user_id_claim] - except KeyError as exc: - missing_claim = self.config.jwt_configuration.user_id_claim - response = UnauthorizedResponse( - cause=f"Token missing claim: {missing_claim}" - ) - raise HTTPException(**response.model_dump()) from exc - try: - username: str = claims[self.config.jwt_configuration.username_claim] - except KeyError as exc: - missing_claim = self.config.jwt_configuration.username_claim - response = UnauthorizedResponse( - cause=f"Token missing claim: {missing_claim}" + user_token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics( + AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time ) - raise HTTPException(**response.model_dump()) from exc + raise + + jwk_set = await _get_jwk_set_for_auth(self.config, start_time) + claims = _decode_jwk_claims(user_token, jwk_set, start_time) + _validate_jwk_claims(claims, start_time) + user_id = _get_required_claim( + claims, self.config.jwt_configuration.user_id_claim, start_time + ) + username = _get_required_claim( + claims, self.config.jwt_configuration.username_claim, start_time + ) logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) + record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time) return user_id, username, self.skip_userid_check, user_token diff --git a/src/authentication/k8s.py b/src/authentication/k8s.py index 6cc99c92c..07b1c9363 100644 --- a/src/authentication/k8s.py +++ b/src/authentication/k8s.py @@ -1,6 +1,7 @@ """Manage authentication flow for FastAPI endpoints with K8S/OCP.""" import os +import time from http import HTTPStatus from typing import Optional, Self, cast @@ -10,9 +11,9 @@ from kubernetes.config import ConfigException from authentication.interface import NO_AUTH_TUPLE, AuthInterface -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from configuration import configuration -from constants import DEFAULT_VIRTUAL_PATH +from constants import AUTH_MOD_K8S, DEFAULT_VIRTUAL_PATH from log import get_logger from models.responses import ( ForbiddenResponse, @@ -385,6 +386,79 @@ def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReviewStatus] return None +def _populate_kube_admin_uid( + user: kubernetes.client.V1UserInfo, start_time: float +) -> None: + """Populate kube:admin UID from the cluster ID and record bounded failures.""" + if user.username != "kube:admin": + return + try: + user.uid = K8sClientSingleton.get_cluster_id() + except K8sAPIConnectionError as e: + logger.error("Cannot connect to Kubernetes API: %s", e) + record_auth_metrics(AUTH_MOD_K8S, "failure", "k8s_api_unavailable", start_time) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + except K8sConfigurationError as e: + logger.error("Cluster configuration error: %s", e) + record_auth_metrics(AUTH_MOD_K8S, "failure", "k8s_config_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error while resolving kube:admin cluster ID") + record_auth_metrics(AUTH_MOD_K8S, "failure", "unexpected_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + + +def _create_subject_access_review( + user: kubernetes.client.V1UserInfo, virtual_path: str, start_time: float +) -> kubernetes.client.V1SubjectAccessReview: + """Create a Kubernetes SubjectAccessReview and record bounded failures.""" + try: + authorization_api = K8sClientSingleton.get_authz_api() + sar = kubernetes.client.V1SubjectAccessReview( + spec=kubernetes.client.V1SubjectAccessReviewSpec( + user=user.username, + groups=user.groups, + non_resource_attributes=kubernetes.client.V1NonResourceAttributes( + path=virtual_path, verb="get" + ), + ) + ) + return cast( + kubernetes.client.V1SubjectAccessReview, + authorization_api.create_subject_access_review(sar), + ) + except ApiException as e: + logger.error("API exception during SubjectAccessReview: %s", e) + record_auth_metrics( + AUTH_MOD_K8S, "failure", "authorization_check_error", start_time + ) + response = ServiceUnavailableResponse( + backend_name="Kubernetes API", + cause="Unable to perform authorization check", + ) + raise HTTPException(**response.model_dump()) from e + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error during SubjectAccessReview") + record_auth_metrics(AUTH_MOD_K8S, "failure", "unexpected_error", start_time) + response = InternalServerErrorResponse( + response="Internal server error", + cause=str(e), + ) + raise HTTPException(**response.model_dump()) from e + + class K8SAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods """FastAPI dependency for Kubernetes (k8s) authentication and authorization. @@ -436,69 +510,47 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: ------ HTTPException: If authentication or authorization fails. """ + start_time = time.monotonic() + # LCORE-694: Config option to skip authorization for readiness and liveness probe if not request.headers.get("Authorization"): if configuration.authentication_configuration.skip_for_health_probes: if request.url.path in ("/readiness", "/liveness"): + record_auth_metrics( + AUTH_MOD_K8S, "skipped", "health_probe", start_time + ) return NO_AUTH_TUPLE # Skip auth for metrics endpoint when configured if configuration.authentication_configuration.skip_for_metrics: if request.url.path in ("/metrics",): + record_auth_metrics(AUTH_MOD_K8S, "skipped", "metrics", start_time) return NO_AUTH_TUPLE - token = extract_user_token(request.headers) - user_info = get_user_info(token) + try: + token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics(AUTH_MOD_K8S, "failure", "missing_token", start_time) + raise + try: + user_info = get_user_info(token) + except HTTPException: + record_auth_metrics( + AUTH_MOD_K8S, "failure", "token_review_error", start_time + ) + raise if user_info is None: + record_auth_metrics(AUTH_MOD_K8S, "failure", "invalid_token", start_time) response = UnauthorizedResponse(cause="Invalid or expired Kubernetes token") raise HTTPException(**response.model_dump()) # Cast user to proper type for type checking user = cast(kubernetes.client.V1UserInfo, user_info.user) - if user.username == "kube:admin": - try: - user.uid = K8sClientSingleton.get_cluster_id() - except K8sAPIConnectionError as e: - # Kubernetes API is unreachable - return 503 - logger.error("Cannot connect to Kubernetes API: %s", e) - response = ServiceUnavailableResponse( - backend_name="Kubernetes API", - cause=str(e), - ) - raise HTTPException(**response.model_dump()) from e - except K8sConfigurationError as e: - # Cluster misconfiguration or client error - return 500 - logger.error("Cluster configuration error: %s", e) - response = InternalServerErrorResponse( - response="Internal server error", - cause=str(e), - ) - raise HTTPException(**response.model_dump()) from e - - try: - authorization_api = K8sClientSingleton.get_authz_api() - sar = kubernetes.client.V1SubjectAccessReview( - spec=kubernetes.client.V1SubjectAccessReviewSpec( - user=user.username, - groups=user.groups, - non_resource_attributes=kubernetes.client.V1NonResourceAttributes( - path=self.virtual_path, verb="get" - ), - ) - ) - sar_response = cast( - kubernetes.client.V1SubjectAccessReview, - authorization_api.create_subject_access_review(sar), - ) - - except Exception as e: - logger.error("API exception during SubjectAccessReview: %s", e) - response = ServiceUnavailableResponse( - backend_name="Kubernetes API", - cause="Unable to perform authorization check", - ) - raise HTTPException(**response.model_dump()) from e + _populate_kube_admin_uid(user, start_time) + sar_response = _create_subject_access_review( + user, self.virtual_path, start_time + ) sar_status = cast( kubernetes.client.V1SubjectAccessReviewStatus, sar_response.status @@ -507,9 +559,11 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: username = cast(str, user.username) if not sar_status.allowed: + record_auth_metrics(AUTH_MOD_K8S, "failure", "not_authorized", start_time) response = ForbiddenResponse.endpoint(user_id=user_uid) raise HTTPException(**response.model_dump()) + record_auth_metrics(AUTH_MOD_K8S, "success", "authenticated", start_time) return ( user_uid, username, diff --git a/src/authentication/noop.py b/src/authentication/noop.py index a4188aa94..5688dfd56 100644 --- a/src/authentication/noop.py +++ b/src/authentication/noop.py @@ -1,9 +1,13 @@ """Manage authentication flow for FastAPI endpoints with no-op auth.""" +import time + from fastapi import HTTPException, Request from authentication.interface import AuthInterface +from authentication.utils import record_auth_metrics from constants import ( + AUTH_MOD_NOOP, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -47,6 +51,8 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: - skip_userid_check: True to indicate the user ID check is skipped. - token: NO_USER_TOKEN. """ + start_time = time.monotonic() + logger.warning( "No-op authentication dependency is being used. " "The service is running in insecure mode intended solely for development purposes" @@ -54,6 +60,8 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) if not user_id: + record_auth_metrics(AUTH_MOD_NOOP, "failure", "empty_user_id", start_time) raise HTTPException(status_code=400, detail="user_id cannot be empty") logger.debug("Retrieved user ID: %s", user_id) + record_auth_metrics(AUTH_MOD_NOOP, "skipped", "no_auth_required", start_time) return user_id, DEFAULT_USER_NAME, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/authentication/noop_with_token.py b/src/authentication/noop_with_token.py index 45127e2c9..509d42334 100644 --- a/src/authentication/noop_with_token.py +++ b/src/authentication/noop_with_token.py @@ -9,11 +9,14 @@ - Returns a tuple: (user_id, DEFAULT_USER_NAME, user_token). """ +import time + from fastapi import HTTPException, Request from authentication.interface import AuthInterface -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics from constants import ( + AUTH_MOD_NOOP_WITH_TOKEN, DEFAULT_USER_NAME, DEFAULT_USER_UID, DEFAULT_VIRTUAL_PATH, @@ -59,15 +62,29 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: - skip_userid_check: True to indicate user-id checks are skipped. - user_token: Token extracted from the request headers. """ + start_time = time.monotonic() + logger.warning( "No-op with token authentication dependency is being used. " "The service is running in insecure mode intended solely for development purposes" ) # try to extract user token from request - user_token = extract_user_token(request.headers) + try: + user_token = extract_user_token(request.headers) + except HTTPException: + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "failure", "missing_token", start_time + ) + raise # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) if not user_id: + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "failure", "empty_user_id", start_time + ) raise HTTPException(status_code=400, detail="user_id cannot be empty") logger.debug("Retrieved user ID: %s", user_id) + record_auth_metrics( + AUTH_MOD_NOOP_WITH_TOKEN, "skipped", "no_auth_required", start_time + ) return user_id, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authentication/rh_identity.py b/src/authentication/rh_identity.py index 2e9a926ef..b4484689c 100644 --- a/src/authentication/rh_identity.py +++ b/src/authentication/rh_identity.py @@ -6,13 +6,16 @@ import base64 import json +import time from typing import Any, Optional from fastapi import HTTPException, Request from authentication.interface import NO_AUTH_TUPLE, AuthInterface, AuthTuple +from authentication.utils import record_auth_metrics from configuration import configuration from constants import ( + AUTH_MOD_RH_IDENTITY, DEFAULT_RH_IDENTITY_MAX_HEADER_SIZE, DEFAULT_VIRTUAL_PATH, NO_USER_TOKEN, @@ -22,6 +25,24 @@ logger = get_logger(__name__) +def _record_rh_identity_auth(result: str, reason: str, start_time: float) -> None: + """Record RH Identity authentication metrics with bounded labels.""" + record_auth_metrics(AUTH_MOD_RH_IDENTITY, result, reason, start_time) + + +def _get_auth_skip_tuple(request: Request, start_time: float) -> Optional[AuthTuple]: + """Return an auth tuple for configured RH Identity skip paths.""" + if request.url.path.endswith(("/readiness", "/liveness")): + if configuration.authentication_configuration.skip_for_health_probes: + _record_rh_identity_auth("skipped", "health_probe", start_time) + return NO_AUTH_TUPLE + if request.url.path.endswith("/metrics"): + if configuration.authentication_configuration.skip_for_metrics: + _record_rh_identity_auth("skipped", "metrics", start_time) + return NO_AUTH_TUPLE + return None + + class RHIdentityData: """Extracts and validates Red Hat Identity header data. @@ -300,18 +321,16 @@ async def __call__(self, request: Request) -> AuthTuple: - 400: Invalid base64, invalid JSON, or missing required fields - 403: Missing required entitlements """ + start_time = time.monotonic() + # Extract header identity_header = request.headers.get("x-rh-identity") if not identity_header: - # Skip auth for health probes when configured - if request.url.path.endswith(("/readiness", "/liveness")): - if configuration.authentication_configuration.skip_for_health_probes: - return NO_AUTH_TUPLE - # Skip auth for metrics endpoint when configured - if request.url.path.endswith("/metrics"): - if configuration.authentication_configuration.skip_for_metrics: - return NO_AUTH_TUPLE + auth_skip = _get_auth_skip_tuple(request, start_time) + if auth_skip is not None: + return auth_skip logger.warning("Missing x-rh-identity header") + _record_rh_identity_auth("failure", "missing_header", start_time) raise HTTPException(status_code=401, detail="Missing x-rh-identity header") # Enforce header size limit before decoding @@ -321,6 +340,7 @@ async def __call__(self, request: Request) -> AuthTuple: len(identity_header), self.max_header_size, ) + _record_rh_identity_auth("failure", "header_too_large", start_time) raise HTTPException( status_code=400, detail="x-rh-identity header exceeds maximum allowed size", @@ -332,6 +352,7 @@ async def __call__(self, request: Request) -> AuthTuple: decoded_str = decoded_bytes.decode("utf-8") except (ValueError, UnicodeDecodeError) as exc: logger.warning("Invalid base64 in x-rh-identity header: %s", exc) + _record_rh_identity_auth("failure", "invalid_base64", start_time) raise HTTPException( status_code=400, detail="Invalid base64 encoding in x-rh-identity header", @@ -342,18 +363,26 @@ async def __call__(self, request: Request) -> AuthTuple: identity_data = json.loads(decoded_str) except json.JSONDecodeError as exc: logger.warning("Invalid JSON in x-rh-identity header: %s", exc) + _record_rh_identity_auth("failure", "invalid_json", start_time) raise HTTPException( status_code=400, detail="Invalid JSON in x-rh-identity header" ) from exc # Extract and validate identity - rh_identity = RHIdentityData( - identity_data, - required_entitlements=self.required_entitlements, - ) + try: + rh_identity = RHIdentityData( + identity_data, + required_entitlements=self.required_entitlements, + ) - # Validate entitlements if configured - rh_identity.validate_entitlements() + # Validate entitlements if configured + rh_identity.validate_entitlements() + except HTTPException as exc: + reason = ( + "entitlement_missing" if exc.status_code == 403 else "invalid_identity" + ) + _record_rh_identity_auth("failure", reason, start_time) + raise # Store identity data in request.state for downstream access request.state.rh_identity_data = rh_identity @@ -365,5 +394,6 @@ async def __call__(self, request: Request) -> AuthTuple: logger.debug( "RH Identity authenticated: user_id=%s, username=%s", user_id, username ) + _record_rh_identity_auth("success", "authenticated", start_time) return user_id, username, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/authentication/utils.py b/src/authentication/utils.py index 6ac99b974..95df73a06 100644 --- a/src/authentication/utils.py +++ b/src/authentication/utils.py @@ -1,10 +1,16 @@ """Authentication utility functions.""" +import time + from fastapi import HTTPException from starlette.datastructures import Headers +from log import get_logger +from metrics import recording from models.responses import UnauthorizedResponse +logger = get_logger(__name__) + def extract_user_token(headers: Headers) -> str: """Extract the bearer token from an HTTP Authorization header. @@ -33,3 +39,28 @@ def extract_user_token(headers: Headers) -> str: raise HTTPException(**response.model_dump()) return scheme_and_token[1] + + +def record_auth_metrics( + auth_module: str, result: str, reason: str, start_time: float +) -> None: + """Record authentication attempt and duration metrics together. + + Args: + auth_module: Configured authentication module name. + result: Bounded result label, such as ``success`` or ``failure``. + reason: Bounded reason label for the result. + start_time: Monotonic clock time captured at the start of auth handling. + """ + try: + recording.record_auth_attempt(auth_module, result, reason) + recording.record_auth_duration( + auth_module, result, time.monotonic() - start_time + ) + except Exception: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to record authentication metrics for module %s with result %s", + auth_module, + result, + exc_info=True, + ) diff --git a/src/authorization/middleware.py b/src/authorization/middleware.py index dcad18703..b6973c886 100644 --- a/src/authorization/middleware.py +++ b/src/authorization/middleware.py @@ -1,5 +1,6 @@ """Authorization middleware and decorators.""" +import time from collections.abc import Callable from functools import lru_cache, wraps from typing import Any, Optional @@ -18,6 +19,7 @@ ) from configuration import configuration from log import get_logger +from metrics import recording from models.config import Action from models.responses import ( ForbiddenResponse, @@ -124,39 +126,50 @@ async def _perform_authorization_check( HTTPException: with 403 Forbidden if the resolved roles are not permitted to perform `action`. """ - role_resolver, access_resolver = get_authorization_resolvers() + start_time = time.monotonic() + result = "error" try: - auth = kwargs["auth"] - except KeyError as exc: - logger.error( - "Authorization only allowed on endpoints that accept " - "'auth: Any = Depends(get_auth_dependency())'" + role_resolver, access_resolver = get_authorization_resolvers() + + try: + auth = kwargs["auth"] + except KeyError as exc: + logger.error( + "Authorization only allowed on endpoints that accept " + "'auth: Any = Depends(get_auth_dependency())'" + ) + response = InternalServerErrorResponse.generic() + raise HTTPException(**response.model_dump()) from exc + + # Everyone gets the everyone (aka *) role + everyone_roles = {"*"} + + user_roles = await role_resolver.resolve_roles(auth) | everyone_roles + + if not access_resolver.check_access(action, user_roles): + response = ForbiddenResponse.endpoint(user_id=auth[0]) + result = "denied" + raise HTTPException(**response.model_dump()) + + authorized_actions = access_resolver.get_actions(user_roles) + + req: Optional[Request] = None + if "request" in kwargs and isinstance(kwargs["request"], Request): + req = kwargs["request"] + else: + for arg in args: + if isinstance(arg, Request): + req = arg + break + if req is not None: + req.state.authorized_actions = authorized_actions + result = "success" + finally: + recording.record_authorization_check(action.value, result) + recording.record_authorization_duration( + action.value, result, time.monotonic() - start_time ) - response = InternalServerErrorResponse.generic() - raise HTTPException(**response.model_dump()) from exc - - # Everyone gets the everyone (aka *) role - everyone_roles = {"*"} - - user_roles = await role_resolver.resolve_roles(auth) | everyone_roles - - if not access_resolver.check_access(action, user_roles): - response = ForbiddenResponse.endpoint(user_id=auth[0]) - raise HTTPException(**response.model_dump()) - - authorized_actions = access_resolver.get_actions(user_roles) - - req: Optional[Request] = None - if "request" in kwargs and isinstance(kwargs["request"], Request): - req = kwargs["request"] - else: - for arg in args: - if isinstance(arg, Request): - req = arg - break - if req is not None: - req.state.authorized_actions = authorized_actions def authorize(action: Action) -> Callable: diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 893e634db..efc0daae6 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -55,3 +55,52 @@ "LLM tokens received", ["provider", "model", "endpoint"], ) + +# Counter to track authentication attempts by configured authentication module. +auth_attempts_total = Counter( + "ls_auth_attempts_total", + "Authentication attempts", + ["auth_module", "result", "reason"], +) + +# Histogram to measure authentication dependency latency. +auth_duration_seconds = Histogram( + "ls_auth_duration_seconds", + "Authentication duration", + ["auth_module", "result"], +) + +# Counter to track authorization checks by protected action. +authorization_checks_total = Counter( + "ls_authorization_checks_total", + "Authorization checks", + ["action", "result"], +) + +# Histogram to measure authorization check latency by protected action. +authorization_duration_seconds = Histogram( + "ls_authorization_duration_seconds", + "Authorization check duration", + ["action", "result"], +) + +# Counter to track pre-request quota checks by bounded quota category. +quota_checks_total = Counter( + "ls_quota_checks_total", + "Quota availability checks", + ["endpoint", "quota_type", "result"], +) + +# Histogram to measure quota availability check latency by bounded quota category. +quota_check_duration_seconds = Histogram( + "ls_quota_check_duration_seconds", + "Quota availability check duration", + ["endpoint", "quota_type", "result"], +) + +# Histogram to measure the latency of direct LLM inference backend calls. +llm_inference_duration_seconds = Histogram( + "ls_llm_inference_duration_seconds", + "LLM inference call duration", + ["provider", "model", "endpoint", "result"], +) diff --git a/src/metrics/recording.py b/src/metrics/recording.py index e7da276cc..65043cc39 100644 --- a/src/metrics/recording.py +++ b/src/metrics/recording.py @@ -109,3 +109,98 @@ def record_llm_token_usage( ) except (AttributeError, TypeError, ValueError): logger.warning("Failed to update token metrics", exc_info=True) + + +def record_auth_attempt(auth_module: str, result: str, reason: str) -> None: + """Record one authentication attempt. + + Args: + auth_module: Configured authentication module name. + result: Bounded result label, such as ``success`` or ``failure``. + reason: Bounded reason label for the result. + """ + try: + metrics.auth_attempts_total.labels(auth_module, result, reason).inc() + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authentication metric", exc_info=True) + + +def record_auth_duration(auth_module: str, result: str, duration: float) -> None: + """Record authentication duration. + + Args: + auth_module: Configured authentication module name. + result: Bounded result label, such as ``success`` or ``failure``. + duration: Authentication duration in seconds. + """ + try: + metrics.auth_duration_seconds.labels(auth_module, result).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authentication duration metric", exc_info=True) + + +def record_authorization_check(action: str, result: str) -> None: + """Record one authorization check. + + Args: + action: Protected action name. + result: Bounded result label, such as ``success`` or ``denied``. + """ + try: + metrics.authorization_checks_total.labels(action, result).inc() + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authorization metric", exc_info=True) + + +def record_authorization_duration(action: str, result: str, duration: float) -> None: + """Record authorization check duration. + + Args: + action: Protected action name. + result: Bounded result label, such as ``success`` or ``denied``. + duration: Authorization check duration in seconds. + """ + try: + metrics.authorization_duration_seconds.labels(action, result).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update authorization duration metric", exc_info=True) + + +def record_quota_check( + endpoint_path: str, quota_type: str, result: str, duration: float +) -> None: + """Record a quota availability check. + + Args: + endpoint_path: API endpoint path for metric labeling. + quota_type: Bounded quota category, not the subject identifier. + result: Bounded result label, such as ``success``, ``skipped``, or ``failure``. + duration: Quota check duration in seconds. + """ + try: + metrics.quota_checks_total.labels(endpoint_path, quota_type, result).inc() + metrics.quota_check_duration_seconds.labels( + endpoint_path, quota_type, result + ).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update quota check metrics", exc_info=True) + + +def record_llm_inference_duration( + provider: str, model: str, endpoint_path: str, result: str, duration: float +) -> None: + """Record the latency of a direct LLM inference backend call. + + Args: + provider: LLM provider identifier. + model: LLM model identifier without the provider prefix. + endpoint_path: API endpoint path for metric labeling. + result: Bounded result label, such as ``success`` or ``failure``. + duration: Inference call duration in seconds. + """ + try: + metrics.llm_inference_duration_seconds.labels( + provider, model, endpoint_path, result + ).observe(duration) + except (AttributeError, TypeError, ValueError): + logger.warning("Failed to update LLM inference duration metric", exc_info=True) diff --git a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py index 96550af75..82da1a26e 100644 --- a/tests/unit/app/endpoints/test_responses.py +++ b/tests/unit/app/endpoints/test_responses.py @@ -2,6 +2,7 @@ """Unit tests for the /responses REST API endpoint (LCORE Responses API).""" import json +from collections.abc import AsyncIterator from datetime import UTC, datetime from typing import Any, Optional, cast @@ -22,6 +23,7 @@ _should_filter_mcp_chunk, handle_non_streaming_response, handle_streaming_response, + response_generator, responses_endpoint_handler, ) from configuration import AppConfig @@ -30,7 +32,12 @@ from models.database.conversations import UserConversation from models.requests import ResponsesRequest from models.responses import ResponsesResponse -from utils.types import RAGContext, ResponsesConversationContext, TurnSummary +from utils.types import ( + RAGContext, + ResponsesApiParams, + ResponsesConversationContext, + TurnSummary, +) MOCK_AUTH = ( "00000001-0001-0001-0001-000000000001", @@ -1260,6 +1267,57 @@ async def mock_stream() -> Any: assert '"available_quotas":{}' in body or '"available_quotas": {}' in body assert "[DONE]" in body + @pytest.mark.asyncio + @pytest.mark.parametrize("yield_chunk_before_error", [False, True]) + async def test_response_generator_records_failure_when_stream_iteration_raises( + self, + minimal_config: AppConfig, + mocker: MockerFixture, + yield_chunk_before_error: bool, + ) -> None: + """Test stream iterator failures record inference failure metrics once.""" + in_progress_chunk = mocker.Mock() + in_progress_chunk.type = "response.in_progress" + in_progress_chunk.model_dump.return_value = { + "type": "response.in_progress", + "response": {"id": "r0"}, + } + + async def failing_stream() -> AsyncIterator[Any]: + """Yield an optional chunk, then raise like a broken backend stream.""" + if yield_chunk_before_error: + yield in_progress_chunk + raise RuntimeError("stream failed") + + request = _request_with_model_and_conv("Hi", model="provider/model1") + mocker.patch(f"{MODULE}.configuration", minimal_config) + mocker.patch(f"{MODULE}.time.monotonic", return_value=13.0) + mock_record = mocker.patch(f"{MODULE}._record_response_inference_result") + + generator = response_generator( + stream=failing_stream(), + original_request=request, + updated_request=request, + api_params=ResponsesApiParams.model_validate(request.model_dump()), + user_id=MOCK_AUTH[0], + turn_summary=TurnSummary(), + inline_rag_context=RAGContext(), + endpoint_path="/v1/responses", + inference_start_time=10.0, + ) + + with pytest.raises(RuntimeError, match="stream failed"): + async for _ in generator: + pass + + mock_record.assert_called_once_with( + "provider/model1", + "/v1/responses", + "failure", + 3.0, + record_failure=True, + ) + @pytest.mark.asyncio async def test_handle_streaming_builds_tool_call_summary_from_output( self, diff --git a/tests/unit/authentication/test_utils.py b/tests/unit/authentication/test_utils.py index 36490d2bb..f9f337b5c 100644 --- a/tests/unit/authentication/test_utils.py +++ b/tests/unit/authentication/test_utils.py @@ -3,9 +3,10 @@ from typing import cast from fastapi import HTTPException +from pytest_mock import MockerFixture from starlette.datastructures import Headers -from authentication.utils import extract_user_token +from authentication.utils import extract_user_token, record_auth_metrics def test_extract_user_token() -> None: @@ -41,3 +42,18 @@ def test_extract_user_token_invalid_format() -> None: "Missing or invalid credentials provided by client" ) assert detail["cause"] == "No token found in Authorization header" + + +def test_record_auth_metrics_records_attempt_and_duration( + mocker: MockerFixture, +) -> None: + """Test recording auth attempt and duration through the shared helper.""" + mock_attempt = mocker.patch("authentication.utils.recording.record_auth_attempt") + mock_duration = mocker.patch("authentication.utils.recording.record_auth_duration") + mock_monotonic = mocker.patch("authentication.utils.time.monotonic") + mock_monotonic.return_value = 12.5 + + record_auth_metrics("jwk-token", "success", "authenticated", 10.0) + + mock_attempt.assert_called_once_with("jwk-token", "success", "authenticated") + mock_duration.assert_called_once_with("jwk-token", "success", 2.5) diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py index 8012ac311..e4d3471ff 100644 --- a/tests/unit/metrics/test_recording.py +++ b/tests/unit/metrics/test_recording.py @@ -1,10 +1,38 @@ """Unit tests for Prometheus metric recording helpers.""" +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import pytest from pytest_mock import MockerFixture from metrics import recording +@dataclass(frozen=True) +class CounterRecorderCase: + """Expected behavior for a counter-style metric recorder.""" + + metric_path: str + recorder: Callable[..., None] + args: tuple[object, ...] + labels: tuple[object, ...] + warning_message: str + + +@dataclass(frozen=True) +class HistogramRecorderCase: + """Expected behavior for a histogram-style metric recorder.""" + + metric_path: str + recorder: Callable[..., None] + args: tuple[object, ...] + labels: tuple[object, ...] + duration: float + warning_message: str + + def test_measure_response_duration_records_timer(mocker: MockerFixture) -> None: """Test that response duration measurement uses the path label timer.""" mock_timer = mocker.MagicMock() @@ -159,3 +187,134 @@ def test_record_llm_token_usage_logs_metric_errors(mocker: MockerFixture) -> Non mock_logger.warning.assert_called_once_with( "Failed to update token metrics", exc_info=True ) + + +@pytest.fixture(name="recording_logger") +def recording_logger_fixture(mocker: MockerFixture) -> Any: + """Patch the metric recording logger for failure assertions.""" + return mocker.patch("metrics.recording.logger") + + +@pytest.mark.parametrize( + "case", + [ + CounterRecorderCase( + metric_path="metrics.recording.metrics.auth_attempts_total", + recorder=recording.record_auth_attempt, + args=("rh-identity", "success", "authenticated"), + labels=("rh-identity", "success", "authenticated"), + warning_message="Failed to update authentication metric", + ), + CounterRecorderCase( + metric_path="metrics.recording.metrics.authorization_checks_total", + recorder=recording.record_authorization_check, + args=("responses", "success"), + labels=("responses", "success"), + warning_message="Failed to update authorization metric", + ), + ], +) +def test_counter_recorders_update_metrics_and_log_errors( + mocker: MockerFixture, + recording_logger: Any, + case: CounterRecorderCase, +) -> None: + """Test new single-counter helpers with shared success and failure coverage.""" + mock_metric = mocker.patch(case.metric_path) + + case.recorder(*case.args) + + mock_metric.labels.assert_called_once_with(*case.labels) + mock_metric.labels.return_value.inc.assert_called_once() + + mock_metric.reset_mock() + mock_metric.labels.return_value.inc.side_effect = AttributeError("missing") + case.recorder(*case.args) + + recording_logger.warning.assert_called_once_with( + case.warning_message, exc_info=True + ) + + +@pytest.mark.parametrize( + "case", + [ + HistogramRecorderCase( + metric_path="metrics.recording.metrics.auth_duration_seconds", + recorder=recording.record_auth_duration, + args=("rh-identity", "success", 0.5), + labels=("rh-identity", "success"), + duration=0.5, + warning_message="Failed to update authentication duration metric", + ), + HistogramRecorderCase( + metric_path="metrics.recording.metrics.authorization_duration_seconds", + recorder=recording.record_authorization_duration, + args=("responses", "success", 0.25), + labels=("responses", "success"), + duration=0.25, + warning_message="Failed to update authorization duration metric", + ), + HistogramRecorderCase( + metric_path="metrics.recording.metrics.llm_inference_duration_seconds", + recorder=recording.record_llm_inference_duration, + args=("vertexai", "gemini", "/v1/responses", "success", 1.5), + labels=("vertexai", "gemini", "/v1/responses", "success"), + duration=1.5, + warning_message="Failed to update LLM inference duration metric", + ), + ], +) +def test_histogram_recorders_observe_metrics_and_log_errors( + mocker: MockerFixture, + recording_logger: Any, + case: HistogramRecorderCase, +) -> None: + """Test new histogram helpers with shared success and failure coverage.""" + mock_metric = mocker.patch(case.metric_path) + + case.recorder(*case.args) + + mock_metric.labels.assert_called_once_with(*case.labels) + mock_metric.labels.return_value.observe.assert_called_once_with(case.duration) + + mock_metric.reset_mock() + mock_metric.labels.return_value.observe.side_effect = TypeError("bad") + case.recorder(*case.args) + + recording_logger.warning.assert_called_once_with( + case.warning_message, exc_info=True + ) + + +@pytest.mark.parametrize("failing_metric", ["counter", "histogram"]) +def test_record_quota_check_updates_metrics_and_logs_errors( + mocker: MockerFixture, + recording_logger: Any, + failing_metric: str, +) -> None: + """Test quota helper counter and histogram updates plus both failure points.""" + mock_counter = mocker.patch("metrics.recording.metrics.quota_checks_total") + mock_histogram = mocker.patch( + "metrics.recording.metrics.quota_check_duration_seconds" + ) + + recording.record_quota_check("/v1/infer", "org_id", "success", 0.75) + + mock_counter.labels.assert_called_once_with("/v1/infer", "org_id", "success") + mock_counter.labels.return_value.inc.assert_called_once() + mock_histogram.labels.assert_called_once_with("/v1/infer", "org_id", "success") + mock_histogram.labels.return_value.observe.assert_called_once_with(0.75) + + mock_counter.reset_mock() + mock_histogram.reset_mock() + if failing_metric == "counter": + mock_counter.labels.return_value.inc.side_effect = TypeError("bad") + else: + mock_histogram.labels.return_value.observe.side_effect = TypeError("bad") + + recording.record_quota_check("/v1/infer", "org_id", "failure", 0.75) + + recording_logger.warning.assert_called_once_with( + "Failed to update quota check metrics", exc_info=True + )