Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/app/endpoints/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,37 @@ def _get_user_agent(request: Request) -> Optional[str]:
return sanitized or None


def _check_response_quota(user_id: str, endpoint_path: str) -> None:
"""Check response quota availability and record bounded quota metrics."""
quota_start_time = time.monotonic()
try:
check_tokens_available(configuration.quota_limiters, user_id)
except HTTPException:
recording.record_quota_check(
endpoint_path,
recording.QUOTA_TYPE_USER_ID,
recording.QUOTA_RESULT_FAILURE,
time.monotonic() - quota_start_time,
)
raise
except Exception: # pylint: disable=broad-exception-caught
# Unexpected quota backend failures still need bounded metrics before
# propagating to the endpoint error handling layer.
recording.record_quota_check(
endpoint_path,
recording.QUOTA_TYPE_USER_ID,
recording.QUOTA_RESULT_ERROR,
time.monotonic() - quota_start_time,
)
raise
recording.record_quota_check(
endpoint_path,
recording.QUOTA_TYPE_USER_ID,
recording.QUOTA_RESULT_SUCCESS,
time.monotonic() - quota_start_time,
)


responses_response: dict[int | str, dict[str, Any]] = {
200: ResponsesResponse.openapi_response(),
401: UnauthorizedResponse.openapi_response(
Expand Down Expand Up @@ -358,11 +389,12 @@ async def responses_endpoint_handler(
started_at = datetime.now(UTC)
rh_identity_context = get_rh_identity_context(request)
user_id, _, _, token = auth
endpoint_path = ENDPOINT_PATH_RESPONSES

await check_mcp_auth(configuration, mcp_headers, token, request.headers)

# Check token availability
check_tokens_available(configuration.quota_limiters, user_id)
_check_response_quota(user_id, endpoint_path)

# Enforce RBAC: optionally disallow overriding model in requests
validate_model_provider_override(
Expand Down
70 changes: 64 additions & 6 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,63 @@ def _resolve_quota_subject(request: Request, auth: AuthTuple) -> Optional[str]:
return system_id


def _check_infer_quota(
request: Request, auth: AuthTuple, endpoint_path: str
) -> Optional[str]:
"""Check infer quota availability and record bounded quota metrics.

Resolves the quota subject from the request and auth context, then
verifies that the subject has tokens available. All outcomes (success,
failure, error, skipped) are recorded as Prometheus metrics.

Args:
request: The incoming FastAPI request used to resolve the quota subject.
auth: Authentication tuple ``(user_id, username, skip_userid_check, token)``.
endpoint_path: API endpoint path for metric labeling.

Returns:
The resolved quota subject identifier, or ``None`` when quota is disabled.

Raises:
HTTPException: Re-raised from the quota limiter when the subject has
exhausted its token allowance (HTTP 429).
"""
quota_id = _resolve_quota_subject(request, auth)
quota_type = configuration.rlsapi_v1.quota_subject or "disabled"
if quota_id is None:
recording.record_quota_check(
endpoint_path, quota_type, recording.QUOTA_RESULT_SKIPPED, 0.0
)
return None

quota_start_time = time.monotonic()
try:
check_tokens_available(configuration.quota_limiters, quota_id)
except HTTPException:
recording.record_quota_check(
endpoint_path,
quota_type,
recording.QUOTA_RESULT_FAILURE,
time.monotonic() - quota_start_time,
)
raise
except Exception: # pylint: disable=broad-exception-caught
recording.record_quota_check(
endpoint_path,
quota_type,
recording.QUOTA_RESULT_ERROR,
time.monotonic() - quota_start_time,
)
raise
recording.record_quota_check(
endpoint_path,
quota_type,
recording.QUOTA_RESULT_SUCCESS,
time.monotonic() - quota_start_time,
)
return quota_id
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _build_infer_response(
response_text: str,
request_id: str,
Expand Down Expand Up @@ -733,16 +790,17 @@ async def infer_endpoint( # pylint: disable=R0914,R0915

logger.info("Processing rlsapi v1 /infer request %s", request_id)

# Quota enforcement: resolve subject and check availability before any work.
# No-op when quota_subject is not configured or no quota limiters exist.
quota_id = _resolve_quota_subject(request, auth)
if quota_id is not None:
# Quota enforcement: check availability before any work and record metrics for
# both enforced and disabled quota paths.
quota_subject = configuration.rlsapi_v1.quota_subject
if quota_subject is not None:
logger.info(
"Checking quota availability for rlsapi v1 request %s using subject type %s",
request_id,
configuration.rlsapi_v1.quota_subject,
quota_subject,
)
check_tokens_available(configuration.quota_limiters, quota_id)
quota_id = _check_infer_quota(request, auth, endpoint_path)
if quota_id is not None:
logger.info(
"Quota availability check passed for rlsapi v1 request %s", request_id
)
Expand Down
33 changes: 31 additions & 2 deletions src/authentication/api_key_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
"""

import secrets
import time

from fastapi import HTTPException, Request, status

from authentication.interface import AuthInterface, AuthTuple
from authentication.utils import extract_user_token
from authentication.utils import (
AUTH_CAUSE_NO_HEADER,
extract_user_token,
record_auth_metrics,
)
from constants import (
AUTH_MOD_APIKEY_TOKEN,
DEFAULT_USER_NAME,
DEFAULT_USER_UID,
DEFAULT_VIRTUAL_PATH,
Expand Down Expand Up @@ -59,16 +65,39 @@ async def __call__(self, request: Request) -> AuthTuple:
HTTPException: If the bearer token is missing or
doesn't match the configured API key (HTTP 401).
"""
start_time = time.monotonic()

# try to extract user token from request
user_token = extract_user_token(request.headers)
try:
user_token = extract_user_token(request.headers)
except HTTPException as exc:
# Distinguish missing header from malformed token
reason = "missing_token"
if (
isinstance(exc.detail, dict)
and exc.detail.get("cause") == AUTH_CAUSE_NO_HEADER
):
reason = "missing_header"
record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "failure", reason, start_time)
raise
Comment thread
coderabbitai[bot] marked this conversation as resolved.
except Exception: # pylint: disable=broad-exception-caught
logger.exception("Unexpected error while extracting API key bearer token")
record_auth_metrics(
AUTH_MOD_APIKEY_TOKEN, "failure", "unexpected_error", start_time
)
raise

# API Key validation. Use secrets.compare_digest for constant-time comparison
if not secrets.compare_digest(
user_token, self.config.api_key.get_secret_value()
):
record_auth_metrics(
AUTH_MOD_APIKEY_TOKEN, "failure", "invalid_key", start_time
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key",
)

record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "success", "valid_key", start_time)
return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.skip_userid_check, user_token
Loading
Loading