Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -12023,6 +12023,11 @@
"$ref": "#/components/schemas/OkpConfiguration",
"title": "OKP configuration",
"description": "OKP provider settings. Only used when 'okp' is listed in rag.inline or rag.tool."
},
"reranker": {
"$ref": "#/components/schemas/RerankerConfiguration",
"title": "Reranker configuration",
"description": "Configuration for neural reranking of RAG chunks using cross-encoder."
}
},
"additionalProperties": false,
Expand Down Expand Up @@ -17820,6 +17825,26 @@
"title": "ReferencedDocument",
"description": "Model representing a document referenced in generating a response.\n\nAttributes:\n doc_url: Url to the referenced doc.\n doc_title: Title of the referenced doc."
},
"RerankerConfiguration": {
"properties": {
"enabled": {
"type": "boolean",
"title": "Reranker enabled",
"description": "When True, reranking applied to RAG chunks. When False, reranking is disabled and original scoring used.",
"default": false
},
"model": {
"type": "string",
"title": "Reranker model",
"description": "Cross-encoder model name for reranking RAG chunks. Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.",
"default": "cross-encoder/ms-marco-MiniLM-L6-v2"
}
},
"additionalProperties": false,
"type": "object",
"title": "RerankerConfiguration",
"description": "Reranker configuration for RAG chunk reranking."
},
"ResponseInput": {
"anyOf": [
{
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ dependencies = [
"jinja2>=3.1.0",
# To be able to fix multiple CVEs, also LCORE-1117
"requests>=2.33.0",
    # Used for RAG chunk reranking (cross-encoder)
    # NOTE(review): the reranker imports sentence-transformers, not datasets —
    # verify this is the right dependency for the cross-encoder feature
"datasets>=4.7.0",
# Used for error tracking and monitoring
"sentry-sdk[fastapi]>=2.58.0",
Expand Down
8 changes: 8 additions & 0 deletions src/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
OkpConfiguration,
QuotaHandlersConfiguration,
RagConfiguration,
RerankerConfiguration,
RlsapiV1Configuration,
ServiceConfiguration,
SplunkConfiguration,
Expand Down Expand Up @@ -465,6 +466,13 @@ def okp(self) -> "OkpConfiguration":
raise LogicError("logic error: configuration is not loaded")
return self._configuration.okp

@property
def reranker(self) -> "RerankerConfiguration":
"""Return reranker configuration."""
if self._configuration is None:
raise LogicError("logic error: configuration is not loaded")
return self._configuration.reranker

@property
def rag_id_mapping(self) -> dict[str, str]:
"""Return mapping from vector_db_id to rag_id from BYOK and OKP RAG config.
Expand Down
5 changes: 5 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@
# Default embedding vector dimension for the sentence transformer model
DEFAULT_EMBEDDING_DIMENSION: Final[int] = 768

# Default sentence transformer cross encoder model for reranking RAG chunk scores
DEFAULT_CROSS_ENCODER_MODEL: Final[str] = "cross-encoder/ms-marco-MiniLM-L6-v2"

# quota limiters constants
USER_QUOTA_LIMITER: Final[str] = "user_limiter"
CLUSTER_QUOTA_LIMITER: Final[str] = "cluster_limiter"
Expand All @@ -192,6 +195,8 @@
# Inline RAG constants
BYOK_RAG_MAX_CHUNKS: Final[int] = 10 # retrieved from BYOK RAG
OKP_RAG_MAX_CHUNKS: Final[int] = 5 # retrieved from OKP RAG
# Score multiplier applied to BYOK chunks after cross-encoder reranking (Solr chunks unchanged)
BYOK_RAG_RERANK_BOOST: Final[float] = 1.2

# Solr OKP constants
SOLR_VECTOR_SEARCH_DEFAULT_K: Final[int] = 5
Expand Down
71 changes: 71 additions & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,34 @@ class OkpConfiguration(ConfigurationBase):
)


class RerankerConfiguration(ConfigurationBase):
    """Reranker configuration for RAG chunk reranking.

    Controls whether retrieved RAG chunks are re-scored with a
    sentence-transformers cross-encoder model, and which model is used.
    """

    # When True, cross-encoder reranking is applied to retrieved RAG chunks.
    enabled: bool = Field(
        default=False,
        title="Reranker enabled",
        description="When True, reranking applied to RAG chunks. "
        "When False, reranking is disabled and original scoring used.",
    )
    # NOTE(review): this default duplicates constants.DEFAULT_CROSS_ENCODER_MODEL;
    # keep the two in sync (presumably hard-coded here to avoid an import cycle
    # with the constants module — verify).
    model: str = Field(
        default="cross-encoder/ms-marco-MiniLM-L6-v2",
        title="Reranker model",
        description="Cross-encoder model name for reranking RAG chunks. "
        "Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.",
    )

    # Private attribute to track if this was explicitly configured
    _explicitly_configured: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def mark_as_explicitly_configured(self) -> Self:
        """Mark this configuration as explicitly set when instantiated from user input.

        ``model_fields_set`` is non-empty only when the caller supplied at least
        one field explicitly; a ``default_factory`` instantiation leaves it
        empty. This lets auto-enable logic distinguish user intent (including an
        explicit ``enabled: false``) from an all-defaults instance.
        """
        if self.model_fields_set:
            self._explicitly_configured = True

        return self
Comment thread
Anxhela21 marked this conversation as resolved.


class AzureEntraIdConfiguration(ConfigurationBase):
"""Microsoft Entra ID authentication attributes for Azure."""

Expand Down Expand Up @@ -1976,6 +2004,12 @@ class Configuration(ConfigurationBase):
"in rag.inline or rag.tool.",
)

reranker: RerankerConfiguration = Field(
default_factory=RerankerConfiguration,
title="Reranker configuration",
description="Configuration for neural reranking of RAG chunks using cross-encoder.",
)

@model_validator(mode="after")
def validate_mcp_auth_headers(self) -> Self:
"""
Expand Down Expand Up @@ -2078,6 +2112,43 @@ def validate_rlsapi_v1_quota_configuration(self) -> Self:

return self

@model_validator(mode="after")
def validate_reranker_auto_enable(self) -> Self:
"""Automatically enable reranker when both BYOK and OKP RAG are configured.

When users have both BYOK entries in byok_rag and OKP
configured in the RAG strategies, automatically
enable the reranker if it's not explicitly disabled. This improves result
quality when multiple knowledge sources are available.

Returns:
Self: The validated configuration instance with reranker potentially enabled.
"""
# Check if BYOK RAG entries are configured
has_byok = len(self.byok_rag) > 0
Comment thread
Anxhela21 marked this conversation as resolved.

# Check if OKP is configured in either inline or tool RAG strategies
# pylint: disable=no-member
has_okp = constants.OKP_RAG_ID in self.rag.inline

# If both BYOK and OKP are present and reranker is using default settings,
# ensure it's enabled for optimal results
if (
has_byok
and has_okp
and not self.reranker._explicitly_configured # pylint: disable=protected-access
and not self.reranker.enabled
):
logger.info(
"Automatically enabling reranker: Both BYOK RAG (%d entries) or "
"other inline RAG and OKP are configured. Reranking improves result "
"quality when multiple knowledge sources are available.",
len(self.byok_rag),
)
self.reranker.enabled = True
Comment thread
Anxhela21 marked this conversation as resolved.

return self

def dump(self, filename: str | Path = "configuration.json") -> None:
"""
Write the current Configuration model to a JSON file.
Expand Down
216 changes: 216 additions & 0 deletions src/utils/reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""Reranker utilities for RAG chunk reranking.

This module contains functionality for reranking RAG chunks using cross-encoder models
to improve the relevance of retrieved documents in RAG applications.
"""

import asyncio
import logging
from typing import Any

import constants
from configuration import configuration
from log import get_logger
from models.common.turn_summary import RAGChunk

logger = get_logger(__name__)

# Lazy-loaded cross-encoder models for reranking RAG chunks (CPU-bound, use in thread).
# Cache models by name to avoid reloading the same model multiple times.
# Not a constant; pylint invalid-name is disabled for this module-level singleton.
_cross_encoder_models: dict[str, Any] = {} # pylint: disable=invalid-name
_cross_encoder_load_lock = asyncio.Lock()


async def _get_cross_encoder(model_name: str) -> Any:
    """Return the lazy-loaded cross-encoder model for reranking.

    Uses double-checked locking: a lock-free cache hit first, then a second
    check under the asyncio lock so concurrent callers cannot load the same
    model twice. The blocking model load runs in a worker thread.

    Args:
        model_name: Name of the cross-encoder model to load.

    Returns:
        Loaded CrossEncoder model instance, or None if reranking is disabled
        or loading fails.
    """
    # Check if reranking is enabled before attempting to load the model
    if not configuration.reranker.enabled:
        logger.debug("Reranker is disabled, not loading cross-encoder model")
        return None

    # Fast path: no lock needed for an already-cached model.
    if model_name in _cross_encoder_models:
        return _cross_encoder_models[model_name]
    async with _cross_encoder_load_lock:
        # Re-check under the lock: another task may have finished loading
        # while this one was waiting.
        if model_name in _cross_encoder_models:
            return _cross_encoder_models[model_name]
        try:
            # Imported lazily so the heavy sentence-transformers dependency is
            # only paid for when reranking is actually enabled.
            from sentence_transformers import (
                CrossEncoder,
            )  # pylint: disable=import-outside-toplevel

            # CrossEncoder construction is blocking (model download/load);
            # run it off the event loop.
            model = await asyncio.to_thread(CrossEncoder, model_name)
            _cross_encoder_models[model_name] = model
            logger.info("Loaded cross-encoder for RAG reranking: %s", model_name)
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Could not load cross-encoder for reranking (%s): %s", model_name, e
            )
            # Negative caching: a failed load is cached as None.
            # NOTE(review): this means a transient failure disables reranking
            # for the process lifetime — confirm that no retry is intended.
            _cross_encoder_models[model_name] = None
        return _cross_encoder_models[model_name]


# pylint: disable=too-many-locals
async def rerank_chunks_with_cross_encoder(
    query: str,
    chunks: list[RAGChunk],
    top_k: int,
) -> list[RAGChunk]:
    """Rerank chunks using the configured cross-encoder model.

    Each chunk's final score is a weighted mix of its min-max-normalized
    cross-encoder relevance score (30%) and its min-max-normalized original
    retrieval score (70%), so upstream score multipliers stay influential.

    Args:
        query: The search query
        chunks: RAG chunks to rerank (should contain original weighted scores)
        top_k: Number of top chunks to return

    Returns:
        Top top_k chunks sorted by combined cross-encoder and weighted score
        (descending). On any failure, falls back to the top_k chunks sorted by
        their original scores.
    """
    if not chunks:
        return []

    try:
        # Use the configured model (fixes the config field being ignored);
        # fall back to the package default if unset.
        model_name = (
            configuration.reranker.model or constants.DEFAULT_CROSS_ENCODER_MODEL
        )
        model = await _get_cross_encoder(model_name)
        if model is None:
            raise RuntimeError(f"Failed to load cross-encoder model: {model_name}")

        logger.debug("Using cross-encoder model: %s", model_name)

        # Score every (query, chunk) pair; predict is CPU-bound, so run it in
        # a worker thread to keep the event loop responsive.
        pairs = [(query, chunk.content) for chunk in chunks]
        scores = await asyncio.to_thread(model.predict, pairs)

        # predict may return a numpy array; convert to a plain list.
        if hasattr(scores, "tolist"):
            scores = scores.tolist()

        # Normalize both score families to [0,1] so they are comparable.
        normalized_ce_scores = _min_max_normalize(scores)
        original_scores = [
            chunk.score if chunk.score is not None else 0.0 for chunk in chunks
        ]
        normalized_orig_scores = _min_max_normalize(original_scores)

        # Combine cross-encoder scores with original weighted scores
        # (favor original weighted scores so score multipliers applied
        # upstream remain influential in the final ranking).
        # Weight: 30% cross-encoder, 70% original weighted scores
        combined_scores = [
            (0.3 * ce_score + 0.7 * orig_score)
            for ce_score, orig_score in zip(
                normalized_ce_scores, normalized_orig_scores, strict=True
            )
        ]

        # Pair scores with chunks and sort by combined score (descending);
        # sorted() is stable, so ties keep their original order.
        indexed = sorted(
            zip(combined_scores, chunks, strict=True),
            key=lambda pair: pair[0],
            reverse=True,
        )
        top_indexed = indexed[:top_k]

        # Log the score combination results
        logger.info(
            "Cross-encoder scoring completed: combined %d cross-encoder + "
            "original scores (30%%/70%% mix), returning top %d chunks",
            len(chunks),
            len(top_indexed),
        )
        if logger.isEnabledFor(logging.DEBUG):
            for i, (score, chunk) in enumerate(top_indexed[:3]):  # Show top 3
                logger.debug(
                    "Reranked chunk %d: source=%s, combined_score=%.3f, content_preview='%.50s...'",
                    i + 1,
                    chunk.source,
                    score,
                    chunk.content,
                )

        # Return fresh RAGChunk instances carrying the combined scores.
        return [
            RAGChunk(
                content=chunk.content,
                source=chunk.source,
                score=float(score),
                attributes=chunk.attributes,
            )
            for score, chunk in top_indexed
        ]

    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.warning(
            "Cross-encoder reranking failed, falling back to original scoring: %s", e
        )
        # Fallback: sort by original score and take top_k
        sorted_chunks = sorted(
            chunks,
            key=lambda c: c.score if c.score is not None else float("-inf"),
            reverse=True,
        )
        return sorted_chunks[:top_k]


def _min_max_normalize(values: list[float]) -> list[float]:
    """Min-max normalize values into the [0, 1] range.

    Edge cases match the previous inline logic: a single value maps to 1.0,
    and identical values all map to 0.5.
    """
    if len(values) <= 1:
        return [1.0] * len(values)
    low = min(values)
    span = max(values) - low
    if span == 0:
        # All scores are identical; no ordering information to extract.
        return [0.5] * len(values)
    return [(value - low) / span for value in values]


def apply_byok_rerank_boost(
    chunks: list[RAGChunk], boost: float = constants.BYOK_RAG_RERANK_BOOST
) -> list[RAGChunk]:
    """Apply a score multiplier to BYOK chunks (source != OKP) and re-sort by score.

    Args:
        chunks: RAG chunks after reranking (may be from BYOK or Solr).
        boost: Multiplier applied to BYOK chunk scores. Solr chunks unchanged.

    Returns:
        Same chunks with BYOK scores boosted, sorted by score descending.
    """

    def _rescored(chunk: RAGChunk) -> float:
        # Missing scores sink to the bottom; only non-OKP (BYOK) sources
        # receive the multiplier.
        base = float("-inf") if chunk.score is None else chunk.score
        return base if chunk.source == constants.OKP_RAG_ID else base * boost

    rescored_chunks = [
        RAGChunk(
            content=chunk.content,
            source=chunk.source,
            score=_rescored(chunk),
            attributes=chunk.attributes,
        )
        for chunk in chunks
    ]
    return sorted(
        rescored_chunks,
        key=lambda c: float("-inf") if c.score is None else c.score,
        reverse=True,
    )
Loading
Loading