Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 260 additions & 69 deletions src/app/endpoints/responses.py

Large diffs are not rendered by default.

43 changes: 38 additions & 5 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
"""
inference_time = time.monotonic() - start_time
recording.record_llm_failure(provider, model, endpoint_path)
recording.record_llm_inference_duration(
provider, model, endpoint_path, "failure", inference_time
)
_queue_splunk_event(
background_tasks,
infer_request,
Expand Down Expand Up @@ -532,6 +535,36 @@ def _resolve_quota_subject(request: Request, auth: AuthTuple) -> Optional[str]:
return system_id


def _check_infer_quota(
request: Request, auth: AuthTuple, endpoint_path: str
) -> Optional[str]:
"""Check infer quota availability and record bounded quota metrics."""
Comment on lines +538 to +541
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial | 💤 Low value

Add docstring detail for return value.

The function docstring lacks documentation for the return value. Per coding guidelines, docstrings should include a Returns section.

📝 Suggested docstring improvement
 def _check_infer_quota(
     request: Request, auth: AuthTuple, endpoint_path: str
 ) -> Optional[str]:
-    """Check infer quota availability and record bounded quota metrics."""
+    """Check infer quota availability and record bounded quota metrics.
+
+    Args:
+        request: The FastAPI request object for resolving identity context.
+        auth: Authentication tuple from the configured auth provider.
+        endpoint_path: API endpoint path for metric labeling.
+
+    Returns:
+        The resolved quota subject identifier, or None if quota is disabled.
+
+    Raises:
+        HTTPException: 429 if quota is exhausted.
+    """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/app/endpoints/rlsapi_v1.py` around lines 538 - 541, Update the docstring
for _check_infer_quota(request: Request, auth: AuthTuple, endpoint_path: str) to
include a Returns section that documents the return type Optional[str],
specifying that it returns the resolved quota subject identifier when the quota
check passes and returns None when no quota subject is resolved and the check is
skipped; also add a Raises section noting that the HTTPException raised by
check_tokens_available is re-raised after the failure metric is recorded; follow
the project's docstring style (brief summary, Args, Returns, Raises) and place
these sections after the summary as with other endpoint helpers.

quota_id = _resolve_quota_subject(request, auth)
quota_type = configuration.rlsapi_v1.quota_subject or "disabled"
if quota_id is None:
recording.record_quota_check(endpoint_path, quota_type, "skipped", 0.0)
return None

quota_start_time = time.monotonic()
try:
check_tokens_available(configuration.quota_limiters, quota_id)
except HTTPException:
recording.record_quota_check(
endpoint_path,
quota_type,
"failure",
time.monotonic() - quota_start_time,
)
raise
recording.record_quota_check(
endpoint_path,
quota_type,
"success",
time.monotonic() - quota_start_time,
)
return quota_id


def _build_infer_response(
response_text: str,
request_id: str,
Expand Down Expand Up @@ -669,12 +702,11 @@ async def infer_endpoint( # pylint: disable=R0914
"""
# Authentication enforced by get_auth_dependency(), authorization by @authorize decorator.
check_configuration_loaded(configuration)
endpoint_path = "/v1/infer"

# Quota enforcement: resolve subject and check availability before any work.
# No-op when quota_subject is not configured or no quota limiters exist.
quota_id = _resolve_quota_subject(request, auth)
if quota_id is not None:
check_tokens_available(configuration.quota_limiters, quota_id)
quota_id = _check_infer_quota(request, auth, endpoint_path)

request_id = get_suid()

Expand All @@ -685,8 +717,6 @@ async def infer_endpoint( # pylint: disable=R0914
"Request %s: Combined input source length: %d", request_id, len(input_source)
)

endpoint_path = "/v1/infer"

# Run shield moderation on user input before inference.
# Uses all configured shields; no-op when no shields are registered.
# Runs before model/tool discovery so blocked requests short-circuit
Expand Down Expand Up @@ -721,6 +751,9 @@ async def infer_endpoint( # pylint: disable=R0914
response_text = extract_text_from_response_items(response.output)
token_usage = extract_token_usage(response.usage, model_id, endpoint_path)
inference_time = time.monotonic() - start_time
recording.record_llm_inference_duration(
provider, model, endpoint_path, "success", inference_time
)
except _INFER_HANDLED_EXCEPTIONS as error:
if response is not None:
extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type]
Expand Down
18 changes: 16 additions & 2 deletions src/authentication/api_key_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
"""

import secrets
import time

from fastapi import HTTPException, Request, status

from authentication.interface import AuthInterface
from authentication.utils import extract_user_token
from authentication.utils import extract_user_token, record_auth_metrics
from constants import (
AUTH_MOD_APIKEY_TOKEN,
DEFAULT_USER_NAME,
DEFAULT_USER_UID,
DEFAULT_VIRTUAL_PATH,
Expand Down Expand Up @@ -59,16 +61,28 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]:
HTTPException: If the bearer token is missing or
doesn't match the configured API key (HTTP 401).
"""
start_time = time.monotonic()

# try to extract user token from request
user_token = extract_user_token(request.headers)
try:
user_token = extract_user_token(request.headers)
except HTTPException:
record_auth_metrics(
AUTH_MOD_APIKEY_TOKEN, "failure", "missing_token", start_time
)
raise

# API Key validation. Use secrets.compare_digest for constant-time comparison
if not secrets.compare_digest(
user_token, self.config.api_key.get_secret_value()
):
record_auth_metrics(
AUTH_MOD_APIKEY_TOKEN, "failure", "invalid_key", start_time
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key",
)

record_auth_metrics(AUTH_MOD_APIKEY_TOKEN, "success", "valid_key", start_time)
return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.skip_userid_check, user_token
164 changes: 103 additions & 61 deletions src/authentication/jwk_token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Manage authentication flow for FastAPI endpoints with JWK based JWT auth."""

import json
import time
from asyncio import Lock
from collections.abc import Callable
from typing import Any
Expand All @@ -17,8 +18,9 @@
from fastapi import HTTPException, Request

from authentication.interface import AuthInterface, AuthTuple
from authentication.utils import extract_user_token
from authentication.utils import extract_user_token, record_auth_metrics
from constants import (
AUTH_MOD_JWK_TOKEN,
DEFAULT_VIRTUAL_PATH,
)
from log import get_logger
Expand Down Expand Up @@ -139,6 +141,85 @@ def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key:
return _internal


async def _get_jwk_set_for_auth(config: JwkConfiguration, start_time: float) -> KeySet:
    """Fetch the JWK set for ``config.url`` and convert fetch failures to auth errors.

    Args:
        config: JWK configuration; its ``url`` is stringified and passed to
            ``get_jwk_set``.
        start_time: Timestamp forwarded unchanged to ``record_auth_metrics``
            when a failure is recorded.

    Returns:
        The key set produced by ``get_jwk_set``.

    Raises:
        HTTPException: Built from an ``UnauthorizedResponse`` when the fetch
            raises ``aiohttp.ClientError``, ``json.JSONDecodeError`` or
            ``JoseError``; the original exception is chained as the cause.
    """
    try:
        jwk_set = await get_jwk_set(str(config.url))
    except (aiohttp.ClientError, json.JSONDecodeError, JoseError) as exc:
        # The isinstance checks run in the same priority order as the
        # exception types are listed, so the first matching category wins.
        if isinstance(exc, aiohttp.ClientError):
            logger.error("Failed to fetch JWK set: %s", exc)
            reason = "jwk_fetch_error"
            cause = "Unable to reach authentication key server"
        elif isinstance(exc, json.JSONDecodeError):
            logger.error("Invalid JSON in JWK set response: %s", exc)
            reason = "invalid_json"
            cause = "Authentication key server returned invalid data"
        else:
            logger.error("Invalid JWK set format: %s", exc)
            reason = "invalid_jwk"
            cause = "Authentication keys are malformed"

        # The metric is recorded before raising so every failure is counted.
        record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", reason, start_time)
        failure = UnauthorizedResponse(cause=cause)
        raise HTTPException(**failure.model_dump()) from exc
    return jwk_set


def _decode_jwk_claims(user_token: str, jwk_set: KeySet, start_time: float) -> Any:
    """Decode ``user_token`` against ``jwk_set`` and convert decode errors to auth errors.

    Args:
        user_token: The raw token string to decode.
        jwk_set: Key set handed to ``key_resolver_func`` to build the key
            resolver used by ``jwt.decode``.
        start_time: Timestamp forwarded unchanged to ``record_auth_metrics``
            when a failure is recorded.

    Returns:
        Whatever ``jwt.decode`` returns for the token.

    Raises:
        HTTPException: Built from an ``UnauthorizedResponse`` when decoding
            raises one of the handled error types; the original exception is
            chained as the cause.
    """
    handled_errors = (KeyNotFoundError, BadSignatureError, DecodeError, JoseError)
    try:
        decoded = jwt.decode(user_token, key=key_resolver_func(jwk_set))
    except handled_errors as exc:
        logger.warning("Token decode error: %s", exc)
        # One metric reason covers every decode failure; only the
        # client-facing cause text varies by exception type.
        record_auth_metrics(
            AUTH_MOD_JWK_TOKEN, "failure", "token_decode_error", start_time
        )
        # Lookup is by the exception's exact class, so a subclass not listed
        # here falls through to the generic message.
        cause_by_type = {
            KeyNotFoundError: "Token signed by unknown key",
            BadSignatureError: "Invalid token signature",
            DecodeError: "Token could not be decoded",
            JoseError: "Token format error",
        }
        cause = cause_by_type.get(type(exc), "Unknown token error")
        failure = UnauthorizedResponse(cause=cause)
        raise HTTPException(**failure.model_dump()) from exc
    return decoded


def _validate_jwk_claims(claims: Any, start_time: float) -> None:
    """Run ``claims.validate()`` and convert validation errors to auth errors.

    Args:
        claims: Decoded claims object exposing a ``validate()`` method.
        start_time: Timestamp forwarded unchanged to ``record_auth_metrics``
            when a failure is recorded.

    Raises:
        HTTPException: Built from an ``UnauthorizedResponse`` when validation
            raises ``ExpiredTokenError`` or another ``JoseError``; the original
            exception is chained as the cause.
    """
    try:
        claims.validate()
    except (ExpiredTokenError, JoseError) as exc:
        # Expiry is checked first so it keeps its dedicated metric reason and
        # message; any other handled error is a generic validation failure.
        if isinstance(exc, ExpiredTokenError):
            reason, cause = "token_expired", "Token has expired"
        else:
            reason, cause = "token_validation_error", "Token validation failed"
        record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", reason, start_time)
        failure = UnauthorizedResponse(cause=cause)
        raise HTTPException(**failure.model_dump()) from exc


def _get_required_claim(claims: Any, claim_name: str, start_time: float) -> str:
    """Look up ``claim_name`` in ``claims`` and require a non-empty string value.

    Args:
        claims: Decoded claims object supporting item access by claim name.
        claim_name: Name of the claim to read.
        start_time: Timestamp forwarded unchanged to ``record_auth_metrics``
            when a failure is recorded.

    Returns:
        The claim value when it is a non-empty ``str``.

    Raises:
        HTTPException: Built from an ``UnauthorizedResponse`` when the claim is
            absent (item access raises ``KeyError``) or when its value is not
            a non-empty string.
    """
    try:
        claim_value = claims[claim_name]
    except KeyError as exc:
        record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "missing_claim", start_time)
        failure = UnauthorizedResponse(cause=f"Token missing claim: {claim_name}")
        raise HTTPException(**failure.model_dump()) from exc

    # Happy path first: only a non-empty string is accepted.
    if isinstance(claim_value, str) and claim_value:
        return claim_value

    record_auth_metrics(AUTH_MOD_JWK_TOKEN, "failure", "invalid_claim", start_time)
    failure = UnauthorizedResponse(cause=f"Token has invalid claim: {claim_name}")
    raise HTTPException(**failure.model_dump())


class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
"""JWK AuthDependency class for JWK-based JWT authentication."""

Expand Down Expand Up @@ -187,73 +268,34 @@ async def __call__(self, request: Request) -> AuthTuple:
extracted from the validated JWT. Only returned on successful
authentication; all error paths raise HTTPException.
"""
start_time = time.monotonic()

if not request.headers.get("Authorization"):
record_auth_metrics(
AUTH_MOD_JWK_TOKEN, "failure", "missing_header", start_time
)
response = UnauthorizedResponse(cause="No Authorization header found")
raise HTTPException(**response.model_dump())

user_token = extract_user_token(request.headers)

try:
jwk_set = await get_jwk_set(str(self.config.url))
except aiohttp.ClientError as exc:
logger.error("Failed to fetch JWK set: %s", exc)
response = UnauthorizedResponse(
cause="Unable to reach authentication key server"
)
raise HTTPException(**response.model_dump()) from exc
except json.JSONDecodeError as exc:
logger.error("Invalid JSON in JWK set response: %s", exc)
response = UnauthorizedResponse(
cause="Authentication key server returned invalid data"
)
raise HTTPException(**response.model_dump()) from exc
except JoseError as exc:
logger.error("Invalid JWK set format: %s", exc)
response = UnauthorizedResponse(cause="Authentication keys are malformed")
raise HTTPException(**response.model_dump()) from exc

try:
claims = jwt.decode(user_token, key=key_resolver_func(jwk_set))
except (KeyNotFoundError, BadSignatureError, DecodeError, JoseError) as exc:
logger.warning("Token decode error: %s", exc)
cause_map = {
KeyNotFoundError: "Token signed by unknown key",
BadSignatureError: "Invalid token signature",
DecodeError: "Token could not be decoded",
JoseError: "Token format error",
}
response = UnauthorizedResponse(
cause=cause_map.get(type(exc), "Unknown token error")
)
raise HTTPException(**response.model_dump()) from exc

try:
claims.validate()
except ExpiredTokenError as exc:
response = UnauthorizedResponse(cause="Token has expired")
raise HTTPException(**response.model_dump()) from exc
except JoseError as exc:
response = UnauthorizedResponse(cause="Token validation failed")
raise HTTPException(**response.model_dump()) from exc

try:
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
except KeyError as exc:
missing_claim = self.config.jwt_configuration.user_id_claim
response = UnauthorizedResponse(
cause=f"Token missing claim: {missing_claim}"
)
raise HTTPException(**response.model_dump()) from exc

try:
username: str = claims[self.config.jwt_configuration.username_claim]
except KeyError as exc:
missing_claim = self.config.jwt_configuration.username_claim
response = UnauthorizedResponse(
cause=f"Token missing claim: {missing_claim}"
user_token = extract_user_token(request.headers)
except HTTPException:
record_auth_metrics(
AUTH_MOD_JWK_TOKEN, "failure", "missing_token", start_time
)
raise HTTPException(**response.model_dump()) from exc
raise

jwk_set = await _get_jwk_set_for_auth(self.config, start_time)
claims = _decode_jwk_claims(user_token, jwk_set, start_time)
_validate_jwk_claims(claims, start_time)
user_id = _get_required_claim(
claims, self.config.jwt_configuration.user_id_claim, start_time
)
username = _get_required_claim(
claims, self.config.jwt_configuration.username_claim, start_time
)

logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)

record_auth_metrics(AUTH_MOD_JWK_TOKEN, "success", "authenticated", start_time)
return user_id, username, self.skip_userid_check, user_token
Loading
Loading