diff --git a/docs/openapi.json b/docs/openapi.json index 8ff6e171e..94cc320d0 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -12023,6 +12023,11 @@ "$ref": "#/components/schemas/OkpConfiguration", "title": "OKP configuration", "description": "OKP provider settings. Only used when 'okp' is listed in rag.inline or rag.tool." + }, + "reranker": { + "$ref": "#/components/schemas/RerankerConfiguration", + "title": "Reranker configuration", + "description": "Configuration for neural reranking of RAG chunks using cross-encoder." } }, "additionalProperties": false, @@ -17820,6 +17825,26 @@ "title": "ReferencedDocument", "description": "Model representing a document referenced in generating a response.\n\nAttributes:\n doc_url: Url to the referenced doc.\n doc_title: Title of the referenced doc." }, + "RerankerConfiguration": { + "properties": { + "enabled": { + "type": "boolean", + "title": "Reranker enabled", + "description": "When True, reranking applied to RAG chunks. When False, reranking is disabled and original scoring used.", + "default": false + }, + "model": { + "type": "string", + "title": "Reranker model", + "description": "Cross-encoder model name for reranking RAG chunks. Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.", + "default": "cross-encoder/ms-marco-MiniLM-L6-v2" + } + }, + "additionalProperties": false, + "type": "object", + "title": "RerankerConfiguration", + "description": "Reranker configuration for RAG chunk reranking." 
+ }, "ResponseInput": { "anyOf": [ { diff --git a/pyproject.toml b/pyproject.toml index 7b7cc72b6..4bf38553b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ dependencies = [ "jinja2>=3.1.0", # To be able to fix multiple CVEs, also LCORE-1117 "requests>=2.33.0", + # Used for RAG chunk reranking (cross-encoder) "datasets>=4.7.0", # Used for error tracking and monitoring "sentry-sdk[fastapi]>=2.58.0", diff --git a/src/configuration.py b/src/configuration.py index 4eeca0460..c6a376327 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -27,6 +27,7 @@ OkpConfiguration, QuotaHandlersConfiguration, RagConfiguration, + RerankerConfiguration, RlsapiV1Configuration, ServiceConfiguration, SplunkConfiguration, @@ -465,6 +466,13 @@ def okp(self) -> "OkpConfiguration": raise LogicError("logic error: configuration is not loaded") return self._configuration.okp + @property + def reranker(self) -> "RerankerConfiguration": + """Return reranker configuration.""" + if self._configuration is None: + raise LogicError("logic error: configuration is not loaded") + return self._configuration.reranker + @property def rag_id_mapping(self) -> dict[str, str]: """Return mapping from vector_db_id to rag_id from BYOK and OKP RAG config. 
diff --git a/src/constants.py b/src/constants.py index cbc3cdd74..88dd2aee5 100644 --- a/src/constants.py +++ b/src/constants.py @@ -181,6 +181,9 @@ # Default embedding vector dimension for the sentence transformer model DEFAULT_EMBEDDING_DIMENSION: Final[int] = 768 +# Default sentence transformer cross encoder model for reranking RAG chunk scores +DEFAULT_CROSS_ENCODER_MODEL: Final[str] = "cross-encoder/ms-marco-MiniLM-L6-v2" + # quota limiters constants USER_QUOTA_LIMITER: Final[str] = "user_limiter" CLUSTER_QUOTA_LIMITER: Final[str] = "cluster_limiter" @@ -192,6 +195,8 @@ # Inline RAG constants BYOK_RAG_MAX_CHUNKS: Final[int] = 10 # retrieved from BYOK RAG OKP_RAG_MAX_CHUNKS: Final[int] = 5 # retrieved from OKP RAG +# Score multiplier applied to BYOK chunks after cross-encoder reranking (Solr chunks unchanged) +BYOK_RAG_RERANK_BOOST: Final[float] = 1.2 # Solr OKP constants SOLR_VECTOR_SEARCH_DEFAULT_K: Final[int] = 5 diff --git a/src/models/config.py b/src/models/config.py index 646f94c89..38de45180 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1817,6 +1817,34 @@ class OkpConfiguration(ConfigurationBase): ) +class RerankerConfiguration(ConfigurationBase): + """Reranker configuration for RAG chunk reranking.""" + + enabled: bool = Field( + default=False, + title="Reranker enabled", + description="When True, reranking applied to RAG chunks. " + "When False, reranking is disabled and original scoring used.", + ) + model: str = Field( + default="cross-encoder/ms-marco-MiniLM-L6-v2", + title="Reranker model", + description="Cross-encoder model name for reranking RAG chunks. 
" + "Defaults to 'cross-encoder/ms-marco-MiniLM-L6-v2' from sentence-transformers.", + ) + + # Private attribute to track if this was explicitly configured + _explicitly_configured: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def mark_as_explicitly_configured(self) -> Self: + """Mark this configuration as explicitly set when instantiated from user input.""" + if self.model_fields_set: + self._explicitly_configured = True + + return self + + class AzureEntraIdConfiguration(ConfigurationBase): """Microsoft Entra ID authentication attributes for Azure.""" @@ -1976,6 +2004,12 @@ class Configuration(ConfigurationBase): "in rag.inline or rag.tool.", ) + reranker: RerankerConfiguration = Field( + default_factory=RerankerConfiguration, + title="Reranker configuration", + description="Configuration for neural reranking of RAG chunks using cross-encoder.", + ) + @model_validator(mode="after") def validate_mcp_auth_headers(self) -> Self: """ @@ -2078,6 +2112,43 @@ def validate_rlsapi_v1_quota_configuration(self) -> Self: return self + @model_validator(mode="after") + def validate_reranker_auto_enable(self) -> Self: + """Automatically enable reranker when both BYOK and OKP RAG are configured. + + When users have both BYOK entries in byok_rag and OKP + configured in the RAG strategies, automatically + enable the reranker if it's not explicitly disabled. This improves result + quality when multiple knowledge sources are available. + + Returns: + Self: The validated configuration instance with reranker potentially enabled. 
+ """ + # Check if BYOK RAG entries are configured + has_byok = len(self.byok_rag) > 0 + + # Check if OKP is configured in either inline or tool RAG strategies + # pylint: disable=no-member + has_okp = constants.OKP_RAG_ID in self.rag.inline + + # If both BYOK and OKP are present and reranker is using default settings, + # ensure it's enabled for optimal results + if ( + has_byok + and has_okp + and not self.reranker._explicitly_configured # pylint: disable=protected-access + and not self.reranker.enabled + ): + logger.info( + "Automatically enabling reranker: Both BYOK RAG (%d entries) or " + "other inline RAG and OKP are configured. Reranking improves result " + "quality when multiple knowledge sources are available.", + len(self.byok_rag), + ) + self.reranker.enabled = True + + return self + def dump(self, filename: str | Path = "configuration.json") -> None: """ Write the current Configuration model to a JSON file. diff --git a/src/utils/reranker.py b/src/utils/reranker.py new file mode 100644 index 000000000..72cd21b88 --- /dev/null +++ b/src/utils/reranker.py @@ -0,0 +1,216 @@ +"""Reranker utilities for RAG chunk reranking. + +This module contains functionality for reranking RAG chunks using cross-encoder models +to improve the relevance of retrieved documents in RAG applications. +""" + +import asyncio +from typing import Any + +import constants +from configuration import configuration +from log import get_logger +from models.common.turn_summary import RAGChunk + +logger = get_logger(__name__) + +# Lazy-loaded cross-encoder models for reranking RAG chunks (CPU-bound, use in thread). +# Cache models by name to avoid reloading the same model multiple times. +# Not a constant; pylint invalid-name is disabled for this module-level singleton. 
+_cross_encoder_models: dict[str, Any] = {} # pylint: disable=invalid-name +_cross_encoder_load_lock = asyncio.Lock() + + +async def _get_cross_encoder(model_name: str) -> Any: + """Return the lazy-loaded cross-encoder model for reranking. + + Args: + model_name: Name of the cross-encoder model to load. + + Returns: + Loaded CrossEncoder model instance, or None if loading fails. + """ + # Check if reranking is enabled before attempting to load the model + if not configuration.reranker.enabled: + logger.debug("Reranker is disabled, not loading cross-encoder model") + return None + + if model_name in _cross_encoder_models: + return _cross_encoder_models[model_name] + async with _cross_encoder_load_lock: + if model_name in _cross_encoder_models: + return _cross_encoder_models[model_name] + try: + from sentence_transformers import ( + CrossEncoder, + ) # pylint: disable=import-outside-toplevel + + model = await asyncio.to_thread(CrossEncoder, model_name) + _cross_encoder_models[model_name] = model + logger.info("Loaded cross-encoder for RAG reranking: %s", model_name) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Could not load cross-encoder for reranking (%s): %s", model_name, e + ) + _cross_encoder_models[model_name] = None + return _cross_encoder_models[model_name] + + +# pylint: disable=too-many-locals,too-many-branches +async def rerank_chunks_with_cross_encoder( + query: str, + chunks: list[RAGChunk], + top_k: int, +) -> list[RAGChunk]: + """Rerank chunks using configurable cross-encoder model. 
+ + Args: + query: The search query + chunks: RAG chunks to rerank (should contain original weighted scores) + top_k: Number of top chunks to return + + Returns: + Top top_k chunks sorted by combined cross-encoder and weighted score (descending) + """ + if not chunks: + return [] + + try: + # Get the cached cross-encoder model + model_name = constants.DEFAULT_CROSS_ENCODER_MODEL + model = await _get_cross_encoder(model_name) + if model is None: + raise RuntimeError(f"Failed to load cross-encoder model: {model_name}") + + logger.debug("Using cross-encoder model: %s", model_name) + + # Create query-chunk pairs for scoring + pairs = [(query, chunk.content) for chunk in chunks] + scores = await asyncio.to_thread(model.predict, pairs) + + if hasattr(scores, "tolist"): + scores = scores.tolist() + + # Normalize cross-encoder scores to [0,1] range using min-max normalization + if len(scores) > 1: + min_score = min(scores) + max_score = max(scores) + score_range = max_score - min_score + if score_range > 0: + normalized_ce_scores = [ + (score - min_score) / score_range for score in scores + ] + else: + # All scores are identical, assign 0.5 to all + normalized_ce_scores = [0.5] * len(scores) + else: + # Single score, assign 1.0 + normalized_ce_scores = [1.0] * len(scores) + + # Extract original weighted scores and normalize them + original_scores = [ + chunk.score if chunk.score is not None else 0.0 for chunk in chunks + ] + + if len(original_scores) > 1: + min_orig = min(original_scores) + max_orig = max(original_scores) + orig_range = max_orig - min_orig + if orig_range > 0: + normalized_orig_scores = [ + (score - min_orig) / orig_range for score in original_scores + ] + else: + # All original scores identical, assign 0.5 to all + normalized_orig_scores = [0.5] * len(original_scores) + else: + # Single score, assign 1.0 + normalized_orig_scores = [1.0] * len(original_scores) + + # Combine cross-encoder scores with original weighted scores + # (favor original weighted 
scores) + # This ensures score multipliers are still influential in the final ranking + # Weight: 30% cross-encoder, 70% original weighted scores + combined_scores = [ + (0.3 * ce_score + 0.7 * orig_score) + for ce_score, orig_score in zip( + normalized_ce_scores, normalized_orig_scores, strict=True + ) + ] + + # Combine scores with chunks and sort by combined score (descending) + indexed = list(zip(combined_scores, chunks, strict=True)) + indexed.sort(key=lambda x: x[0], reverse=True) + top_indexed = indexed[:top_k] + + # Log the score combination results + logger.info( + "Cross-encoder scoring completed: combined %d cross-encoder + " + "original scores (30%%/70%% mix), returning top %d chunks", + len(chunks), + len(top_indexed), + ) + if logger.isEnabledFor(10): # DEBUG level + for i, (score, chunk) in enumerate(top_indexed[:3]): # Show top 3 + logger.debug( + "Reranked chunk %d: source=%s, combined_score=%.3f, content_preview='%.50s...'", + i + 1, + chunk.source, + score, + chunk.content, + ) + + # Return RAGChunk list with combined scores + return [ + RAGChunk( + content=chunk.content, + source=chunk.source, + score=float(score), + attributes=chunk.attributes, + ) + for score, chunk in top_indexed + ] + + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Cross-encoder reranking failed, falling back to original scoring: %s", e + ) + # Fallback: sort by original score and take top_k + sorted_chunks = sorted( + chunks, + key=lambda c: c.score if c.score is not None else float("-inf"), + reverse=True, + ) + return sorted_chunks[:top_k] + + +def apply_byok_rerank_boost( + chunks: list[RAGChunk], boost: float = constants.BYOK_RAG_RERANK_BOOST +) -> list[RAGChunk]: + """Apply a score multiplier to BYOK chunks (source != OKP) and re-sort by score. + + Args: + chunks: RAG chunks after reranking (may be from BYOK or Solr). + boost: Multiplier applied to BYOK chunk scores. Solr chunks unchanged. 
+ + Returns: + Same chunks with BYOK scores boosted, sorted by score descending. + """ + boosted = [] + for chunk in chunks: + score = chunk.score if chunk.score is not None else float("-inf") + if chunk.source != constants.OKP_RAG_ID: + score = score * boost + boosted.append( + RAGChunk( + content=chunk.content, + source=chunk.source, + score=score, + attributes=chunk.attributes, + ) + ) + boosted.sort( + key=lambda c: c.score if c.score is not None else float("-inf"), + reverse=True, + ) + return boosted diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py index 9d901271e..31e6f1b73 100644 --- a/src/utils/vector_search.py +++ b/src/utils/vector_search.py @@ -21,11 +21,53 @@ from models.common.query import SolrVectorSearchRequest from models.common.responses.types import ResponseInput from models.common.turn_summary import RAGChunk, RAGContext, ReferencedDocument +from utils.reranker import apply_byok_rerank_boost, rerank_chunks_with_cross_encoder from utils.responses import resolve_vector_store_ids logger = get_logger(__name__) +def _filter_documents_for_chunks( + all_documents: list[ReferencedDocument], + final_chunks: list[RAGChunk], +) -> list[ReferencedDocument]: + """Filter documents to match the final set of chunks after reranking. + + Args: + all_documents: All documents extracted from both BYOK and Solr sources. + final_chunks: Final chunks after merging and reranking. + + Returns: + Filtered list of documents that correspond to the final chunks. 
+ """ + # Create a set of unique identifiers from final chunks + final_chunk_identifiers = set() + for chunk in final_chunks: + attrs = chunk.attributes or {} + # Use same logic as original extraction to identify documents + doc_url = ( + attrs.get("reference_url") or attrs.get("doc_url") or attrs.get("docs_url") + ) + doc_id = attrs.get("document_id") or attrs.get("doc_id") + dedup_key = doc_url or doc_id or chunk.source or "" + if dedup_key: + final_chunk_identifiers.add(dedup_key) + + # Filter documents that match final chunk identifiers + filtered_documents = [] + seen = set() + for doc in all_documents: + # Build same dedup key for document + doc_url_str = str(doc.doc_url) if doc.doc_url else None + dedup_key = doc_url_str or doc.source or "" + + if dedup_key in final_chunk_identifiers and dedup_key not in seen: + seen.add(dedup_key) + filtered_documents.append(doc) + + return filtered_documents + + def _get_okp_base_url() -> AnyUrl: """Return OKP document base URL from configuration (rhokp_url), or default if unset. @@ -56,14 +98,16 @@ def _get_solr_vector_store_ids() -> list[str]: def _build_query_params( solr: Optional[SolrVectorSearchRequest] = None, + k: Optional[int] = None, ) -> dict[str, Any]: """Build query parameters for Solr vector_io search. Args: solr: Optional structured Solr request (mode and filters from the API). + k: Optional number of results to return. If not provided, uses default. Returns: - Parameter dictionary for ``vector_io.query``. + Query parameters dict for vector_io.query. 
""" resolved_mode = ( solr.mode @@ -71,7 +115,7 @@ def _build_query_params( else constants.SOLR_VECTOR_SEARCH_DEFAULT_MODE ) params: dict[str, Any] = { - "k": constants.SOLR_VECTOR_SEARCH_DEFAULT_K, + "k": k if k is not None else constants.SOLR_VECTOR_SEARCH_DEFAULT_K, "score_threshold": constants.SOLR_VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD, "mode": resolved_mode, } @@ -107,7 +151,9 @@ def _extract_byok_rag_chunks( ): weighted_score = score * weight doc_id = ( - chunk.metadata.get("document_id", chunk.chunk_id) + chunk.metadata.get("document_id") + or chunk.metadata.get("doc_id") + or chunk.chunk_id if chunk.metadata else chunk.chunk_id ) @@ -180,6 +226,7 @@ async def _query_store_for_byok_rag( vector_store_id: str, query: str, weight: float, + max_chunks: int = constants.BYOK_RAG_MAX_CHUNKS, ) -> list[dict[str, Any]]: """Query a single vector store for BYOK RAG. @@ -188,6 +235,7 @@ async def _query_store_for_byok_rag( vector_store_id: ID of the vector store to query query: Search query string weight: Score multiplier to apply + max_chunks: Maximum number of chunks to request from this store. 
Returns: List of weighted result dictionaries, or empty list on error @@ -197,7 +245,7 @@ async def _query_store_for_byok_rag( vector_store_id=vector_store_id, query=query, params={ - "max_chunks": constants.BYOK_RAG_MAX_CHUNKS, + "max_chunks": max_chunks, "mode": "vector", }, ) @@ -252,7 +300,11 @@ def _process_byok_rag_chunks_for_documents( for result in result_chunks: metadata = result.get("metadata", {}) - doc_id = result.get("doc_id") or metadata.get("document_id") + doc_id = ( + result.get("doc_id") + or metadata.get("document_id") + or metadata.get("doc_id") + ) title = metadata.get("title") reference_url = ( metadata.get("reference_url") @@ -344,26 +396,29 @@ def _process_solr_chunks_for_documents( return doc_ids_from_chunks -async def _fetch_byok_rag( +async def _fetch_byok_rag( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, query: str, - vector_store_ids: Optional[list[str]] = None, # User-facing + vector_store_ids: Optional[list[str]] = None, + max_chunks: Optional[int] = None, ) -> tuple[list[RAGChunk], list[ReferencedDocument]]: """Fetch chunks and documents from BYOK RAG sources. Args: client: The AsyncLlamaStackClient to use for the request query: The search query - configuration: Application configuration vector_store_ids: Optional list of vector store IDs to query. If provided, only these stores will be queried. If None, all stores (excluding Solr) will be queried. + max_chunks: Maximum number of chunks to return. If None, uses + constants.BYOK_RAG_MAX_CHUNKS. 
Returns: Tuple containing: - rag_chunks: RAG chunks from BYOK RAG - referenced_documents: Documents referenced in BYOK RAG results """ + limit = max_chunks if max_chunks is not None else constants.BYOK_RAG_MAX_CHUNKS rag_chunks: list[RAGChunk] = [] referenced_documents: list[ReferencedDocument] = [] @@ -410,6 +465,7 @@ async def _fetch_byok_rag( vector_store_id, query, score_multiplier_mapping.get(vector_store_id, 1.0), + max_chunks=limit, ) for vector_store_id in vector_store_ids_to_query ] @@ -420,7 +476,7 @@ for store_results in results_per_store: all_results.extend(store_results) all_results.sort(key=lambda x: x["weighted_score"], reverse=True) - top_results = all_results[: constants.BYOK_RAG_MAX_CHUNKS] + top_results = all_results[:limit] # Resolve source, log, and convert to RAGChunk in a single pass logger.info("Filtered top %d chunks from BYOK RAG", len(top_results)) @@ -451,7 +507,7 @@ return rag_chunks, referenced_documents -async def _fetch_solr_rag( +async def _fetch_solr_rag( # pylint: disable=too-many-locals client: AsyncLlamaStackClient, query: str, solr: Optional[SolrVectorSearchRequest] = None, @@ -462,6 +518,8 @@ client: The AsyncLlamaStackClient to use for the request query: The user's query solr: Structured Solr inline RAG request from the API (optional). + Note: this function takes no max_chunks parameter; the number of + returned chunks is capped at constants.OKP_RAG_MAX_CHUNKS. 
Returns: Tuple containing: """ rag_chunks: list[RAGChunk] = [] referenced_documents: list[ReferencedDocument] = [] + limit = constants.OKP_RAG_MAX_CHUNKS if not _is_solr_enabled(): logger.info("OKP vector IO is disabled, skipping OKP search") @@ -502,8 +561,8 @@ ) # Limit to top N chunks - top_chunks = query_response.chunks[: constants.OKP_RAG_MAX_CHUNKS] - top_scores = retrieved_scores[: constants.OKP_RAG_MAX_CHUNKS] + top_chunks = query_response.chunks[:limit] + top_scores = retrieved_scores[:limit] # Extract referenced documents from Solr chunks referenced_documents = _process_solr_chunks_for_documents( @@ -516,7 +575,7 @@ ) logger.debug( "Filtered top %d chunks from OKP RAG (%d were retrieved)", - constants.OKP_RAG_MAX_CHUNKS, + limit, len(rag_chunks), ) @@ -527,20 +586,22 @@ return rag_chunks, referenced_documents -async def build_rag_context( +async def build_rag_context( # pylint: disable=too-many-locals,too-many-branches client: AsyncLlamaStackClient, - moderation_decision: str, + moderation_decision: str, # used below to short-circuit when input is blocked query: str, vector_store_ids: Optional[list[str]], solr: Optional[SolrVectorSearchRequest] = None, ) -> RAGContext: """Build RAG context by fetching and merging chunks from all enabled sources. - Enabled sources can be BYOK and/or Solr OKP. + Fetches 2 * BYOK_RAG_MAX_CHUNKS from BYOK and OKP_RAG_MAX_CHUNKS from Solr, + merges and keeps top 2 * BYOK_RAG_MAX_CHUNKS by score, reranks with a + cross-encoder, then keeps the top BYOK_RAG_MAX_CHUNKS for context. Enabled + sources can be BYOK and/or Solr OKP. Args: client: The AsyncLlamaStackClient to use for the request - moderation_decision: The moderation decision query: The user's query vector_store_ids: The vector store IDs to query solr: Structured Solr inline RAG request from the API (optional). 
@@ -551,27 +612,54 @@ async def build_rag_context( if moderation_decision == "blocked": return RAGContext() - # Fetch from all enabled RAG sources in parallel - byok_chunks_task = _fetch_byok_rag(client, query, vector_store_ids) + pool_size = 2 * constants.BYOK_RAG_MAX_CHUNKS + top_k = constants.BYOK_RAG_MAX_CHUNKS + + # Fetch 2*BYOK_RAG_MAX_CHUNKS from each source in parallel + byok_chunks_task = _fetch_byok_rag( + client, query, vector_store_ids, max_chunks=pool_size + ) solr_chunks_task = _fetch_solr_rag(client, query, solr) - (byok_chunks, byok_docs), (solr_chunks, solr_docs) = await asyncio.gather( + (byok_chunks, byok_documents), (solr_chunks, solr_documents) = await asyncio.gather( byok_chunks_task, solr_chunks_task ) - # Merge chunks from all sources (BYOK + Solr) - context_chunks = byok_chunks + solr_chunks + # Merge: combine and sort by score, keep top 2*BYOK_RAG_MAX_CHUNKS + merged = byok_chunks + solr_chunks + merged.sort( + key=lambda c: c.score if c.score is not None else float("-inf"), reverse=True + ) + merged = merged[:pool_size] + + # Rerank full pool with cross-encoder if enabled; boost BYOK then take top_k + if configuration.reranker.enabled: + logger.info( + "Reranker enabled: processing %d chunks with model '%s'", + len(merged), + configuration.reranker.model, + ) + reranked = await rerank_chunks_with_cross_encoder(query, merged, pool_size) + context_chunks = apply_byok_rerank_boost(reranked)[:top_k] + logger.info( + "Reranker completed: returned %d top chunks after BYOK boost", + len(context_chunks), + ) + else: + logger.info("Reranker disabled: using original vector similarity scores") + context_chunks = merged[:top_k] context_text = _format_rag_context(context_chunks, query) logger.debug( - "Inline RAG context built: %d chunks, %d characters", + "Inline RAG context built: %d chunks (after rerank), %d characters", len(context_chunks), len(context_text), ) - # Merge referenced documents from all sources (BYOK + Solr) - top_documents = 
byok_docs + solr_docs + # Filter documents to match final chunks (after reranking) + all_documents = byok_documents + solr_documents + top_documents = _filter_documents_for_chunks(all_documents, context_chunks) return RAGContext( context_text=context_text, @@ -602,7 +690,8 @@ def _build_document_url( Build document URL based on offline flag and available metadata. Args: - offline: Whether to use offline mode (parent_id) or online mode (reference_url) + offline: Whether to use offline + (parent_id) or online mode (reference_url) doc_id: Document ID from chunk metadata reference_url: Reference URL from chunk metadata diff --git a/tests/integration/endpoints/test_query_byok_integration.py b/tests/integration/endpoints/test_query_byok_integration.py index 7a4c65863..bdf080489 100644 --- a/tests/integration/endpoints/test_query_byok_integration.py +++ b/tests/integration/endpoints/test_query_byok_integration.py @@ -1095,6 +1095,9 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable= test_config.configuration.byok_rag = [entry] test_config.configuration.rag.inline = ["big-source"] + # Disable reranker for this test since it's testing chunk capping, not reranking + test_config.configuration.reranker.enabled = False + mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder") mock_client = _build_base_mock_client(mocker) @@ -1153,6 +1156,7 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable= # The lowest-scored chunks from the original set should be excluded # The highest score in the original set is at the last index highest_original_score = chunks_data[-1][2] # score of the last chunk + # When reranker is disabled, BYOK boost is NOT applied, so we expect original score assert response.rag_chunks[0].score == highest_original_score diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 9eb94ba6b..4e4b71069 100644 --- 
a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -102,6 +102,7 @@ def test_dump_configuration(tmp_path: Path) -> None: assert "byok_rag" in content assert "quota_handlers" in content assert "azure_entra_id" in content + assert "reranker" in content # check the whole deserialized JSON file content assert content == { @@ -222,6 +223,10 @@ def test_dump_configuration(tmp_path: Path) -> None: }, "splunk": None, "deployment_environment": "development", + "reranker": { + "enabled": False, + "model": "cross-encoder/ms-marco-MiniLM-L6-v2", + }, } @@ -445,6 +450,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: assert "byok_rag" in content assert "quota_handlers" in content assert "azure_entra_id" in content + assert "reranker" in content # check the whole deserialized JSON file content assert content == { @@ -580,6 +586,10 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: }, "splunk": None, "deployment_environment": "development", + "reranker": { + "enabled": False, + "model": "cross-encoder/ms-marco-MiniLM-L6-v2", + }, } @@ -815,6 +825,10 @@ def test_dump_configuration_with_quota_limiters_different_values( }, "splunk": None, "deployment_environment": "development", + "reranker": { + "enabled": False, + "model": "cross-encoder/ms-marco-MiniLM-L6-v2", + }, } @@ -1025,6 +1039,10 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: }, "splunk": None, "deployment_environment": "development", + "reranker": { + "enabled": False, + "model": "cross-encoder/ms-marco-MiniLM-L6-v2", + }, } @@ -1220,4 +1238,8 @@ def test_dump_configuration_pg_namespace(tmp_path: Path) -> None: }, "splunk": None, "deployment_environment": "development", + "reranker": { + "enabled": False, + "model": "cross-encoder/ms-marco-MiniLM-L6-v2", + }, } diff --git a/tests/unit/models/config/test_reranker_configuration.py b/tests/unit/models/config/test_reranker_configuration.py new 
file mode 100644 index 000000000..06a819071 --- /dev/null +++ b/tests/unit/models/config/test_reranker_configuration.py @@ -0,0 +1,53 @@ +"""Unit tests for RerankerConfiguration model.""" + +import pytest +from pydantic import ValidationError + +import constants +from models.config import RerankerConfiguration + + +class TestRerankerConfiguration: + """Tests for RerankerConfiguration model.""" + + def test_default_values(self) -> None: + """Test that RerankerConfiguration has correct default values.""" + config = RerankerConfiguration() + assert config.enabled is False + assert config.model == constants.DEFAULT_CROSS_ENCODER_MODEL + + def test_custom_model(self) -> None: + """Test configuration with custom cross-encoder model.""" + config = RerankerConfiguration(model="cross-encoder/ms-marco-TinyBERT-L2-v2") + assert config.model == "cross-encoder/ms-marco-TinyBERT-L2-v2" + assert config.enabled is False + + def test_disabled_reranker(self) -> None: + """Test configuration with reranker disabled.""" + config = RerankerConfiguration(enabled=False) + assert config.enabled is False + assert config.model == constants.DEFAULT_CROSS_ENCODER_MODEL + + def test_model_fields_set_detection(self) -> None: + """Test that model_fields_set is properly detected.""" + config = RerankerConfiguration(model="custom-model") + assert config.model == "custom-model" + + def test_all_custom_values(self) -> None: + """Test configuration with all custom values.""" + config = RerankerConfiguration(enabled=False, model="custom-cross-encoder") + assert config.enabled is False + assert config.model == "custom-cross-encoder" + + def test_explicit_configuration_detection(self) -> None: + """Test that explicitly configured values are detected.""" + # Non-default values should mark as explicitly configured + config = RerankerConfiguration(enabled=False) + assert hasattr(config, "_explicitly_configured") + # Note: The actual _explicitly_configured logic is private + # and tested through integration 
tests + + def test_invalid_field_rejected(self) -> None: + """Test that invalid fields are rejected due to extra='forbid'.""" + with pytest.raises(ValidationError): + RerankerConfiguration(invalid_field="value") diff --git a/tests/unit/utils/test_vector_search.py b/tests/unit/utils/test_vector_search.py index 0945bb236..f301e6657 100644 --- a/tests/unit/utils/test_vector_search.py +++ b/tests/unit/utils/test_vector_search.py @@ -1,5 +1,7 @@ """Unit tests for vector search utilities.""" +# pylint: disable=too-many-lines + import pytest from pydantic import AnyUrl from pytest_mock import MockerFixture @@ -8,6 +10,11 @@ from configuration import AppConfig from models.common.query import SolrVectorSearchRequest from models.common.turn_summary import RAGChunk +from utils.reranker import ( + _get_cross_encoder, + apply_byok_rerank_boost, + rerank_chunks_with_cross_encoder, +) from utils.vector_search import ( _build_document_url, _build_query_params, @@ -719,3 +726,554 @@ async def test_byok_enabled_only(self, mocker: MockerFixture) -> None: assert len(context.rag_chunks) > 0 assert "BYOK content" in context.context_text assert "file_search found" in context.context_text + + @pytest.mark.asyncio + async def test_reranker_enabled_calls_cross_encoder( + self, mocker: MockerFixture + ) -> None: + """Test that cross-encoder is called when reranker is enabled.""" + # Mock configuration with reranker enabled + config_mock = mocker.Mock(spec=AppConfig) + byok_rag_mock = mocker.Mock() + byok_rag_mock.rag_id = "rag_1" + byok_rag_mock.vector_db_id = "vs_1" + config_mock.configuration.rag.inline = ["rag_1"] + config_mock.configuration.byok_rag = [byok_rag_mock] + config_mock.inline_solr_enabled = False + config_mock.score_multiplier_mapping = {"vs_1": 1.0} + config_mock.rag_id_mapping = {"vs_1": "rag_1"} + config_mock.reranker.enabled = True + config_mock.reranker.model = "test-model" + mocker.patch("utils.vector_search.configuration", config_mock) + 
mocker.patch("utils.reranker.configuration", config_mock) + + # Mock BYOK search response + chunk_mock = mocker.Mock() + chunk_mock.content = "BYOK content" + chunk_mock.chunk_id = "chunk_1" + chunk_mock.metadata = {"document_id": "doc_1"} + + search_response = mocker.Mock() + search_response.chunks = [chunk_mock] + search_response.scores = [0.9] + + client_mock = mocker.AsyncMock() + client_mock.vector_io.query.return_value = search_response + + # Mock cross-encoder reranking function + mock_rerank = mocker.patch("utils.vector_search.rerank_chunks_with_cross_encoder") + mock_rerank.return_value = [ + RAGChunk(content="BYOK content", source="rag_1", score=0.95) + ] + + context = await build_rag_context(client_mock, "passed", "test query", None) + + # Verify cross-encoder was called + mock_rerank.assert_called_once() + assert mock_rerank.call_args[0][0] == "test query" # query parameter + # Check that chunks were passed as second argument + assert len(mock_rerank.call_args[0][1]) == 1 # chunks parameter + + assert len(context.rag_chunks) > 0 + + @pytest.mark.asyncio + async def test_reranker_disabled_skips_cross_encoder( + self, mocker: MockerFixture + ) -> None: + """Test that cross-encoder is skipped when reranker is disabled.""" + # Mock configuration with reranker disabled + config_mock = mocker.Mock(spec=AppConfig) + byok_rag_mock = mocker.Mock() + byok_rag_mock.rag_id = "rag_1" + byok_rag_mock.vector_db_id = "vs_1" + config_mock.configuration.rag.inline = ["rag_1"] + config_mock.configuration.byok_rag = [byok_rag_mock] + config_mock.inline_solr_enabled = False + config_mock.score_multiplier_mapping = {"vs_1": 1.0} + config_mock.rag_id_mapping = {"vs_1": "rag_1"} + config_mock.reranker.enabled = False + mocker.patch("utils.vector_search.configuration", config_mock) + + # Mock BYOK search response + chunk_mock = mocker.Mock() + chunk_mock.content = "BYOK content" + chunk_mock.chunk_id = "chunk_1" + chunk_mock.metadata = {"document_id": "doc_1"} + + 
search_response = mocker.Mock() + search_response.chunks = [chunk_mock] + search_response.scores = [0.9] + + client_mock = mocker.AsyncMock() + client_mock.vector_io.query.return_value = search_response + + # Mock cross-encoder reranking function + mock_rerank = mocker.patch("utils.reranker.rerank_chunks_with_cross_encoder") + + context = await build_rag_context(client_mock, "passed", "test query", None) + + # Verify cross-encoder was NOT called + mock_rerank.assert_not_called() + + assert len(context.rag_chunks) > 0 + + +class TestGetCrossEncoder: + """Tests for _get_cross_encoder function.""" + + @pytest.mark.asyncio + async def test_loads_model_successfully(self, mocker: MockerFixture) -> None: + """Test successful model loading and caching when reranker is enabled.""" + # Clear the cache for testing + # pylint: disable=import-outside-toplevel + from utils.reranker import _cross_encoder_models + + _cross_encoder_models.clear() + + # Mock reranker configuration to be enabled + mock_config = mocker.Mock() + mock_config.reranker.enabled = True + mocker.patch("utils.vector_search.configuration", mock_config) + mocker.patch("utils.reranker.configuration", mock_config) + + # Mock the CrossEncoder class by patching the import + mock_model_instance = mocker.Mock() + mock_cross_encoder = mocker.Mock(return_value=mock_model_instance) + + # Patch the import at the module level where it happens + mocker.patch.dict( + "sys.modules", + {"sentence_transformers": mocker.Mock(CrossEncoder=mock_cross_encoder)}, + ) + + # Mock asyncio.to_thread + mocker.patch("asyncio.to_thread", return_value=mock_model_instance) + + model = await _get_cross_encoder("test-model") + + assert model == mock_model_instance + + @pytest.mark.asyncio + async def test_caches_loaded_model(self, mocker: MockerFixture) -> None: + """Test that models are cached and not reloaded when reranker is enabled.""" + # Clear the cache for testing + # pylint: disable=import-outside-toplevel + from utils.reranker import 
class TestGetCrossEncoder:
    """Tests for the _get_cross_encoder model-loading helper."""

    @pytest.mark.asyncio
    async def test_loads_model_successfully(self, mocker: MockerFixture) -> None:
        """Successful model loading when the reranker is enabled."""
        # Clear the module-level cache so this test starts fresh.
        # pylint: disable=import-outside-toplevel
        from utils.reranker import _cross_encoder_models

        _cross_encoder_models.clear()

        # Reranker must be enabled for loading to proceed.
        mock_config = mocker.Mock()
        mock_config.reranker.enabled = True
        mocker.patch("utils.vector_search.configuration", mock_config)
        mocker.patch("utils.reranker.configuration", mock_config)

        mock_model_instance = mocker.Mock()
        mock_cross_encoder = mocker.Mock(return_value=mock_model_instance)

        # sentence_transformers is imported lazily inside the helper, so the
        # patch has to go through sys.modules rather than an attribute path.
        mocker.patch.dict(
            "sys.modules",
            {"sentence_transformers": mocker.Mock(CrossEncoder=mock_cross_encoder)},
        )

        # Loading happens off the event loop via asyncio.to_thread.
        mocker.patch("asyncio.to_thread", return_value=mock_model_instance)

        model = await _get_cross_encoder("test-model")

        assert model == mock_model_instance

    @pytest.mark.asyncio
    async def test_caches_loaded_model(self, mocker: MockerFixture) -> None:
        """Models are cached after the first load and not reloaded."""
        # pylint: disable=import-outside-toplevel
        from utils.reranker import _cross_encoder_models

        _cross_encoder_models.clear()

        mock_config = mocker.Mock()
        mock_config.reranker.enabled = True
        mocker.patch("utils.vector_search.configuration", mock_config)
        mocker.patch("utils.reranker.configuration", mock_config)

        mock_model_instance = mocker.Mock()
        mock_cross_encoder = mocker.Mock(return_value=mock_model_instance)

        mocker.patch.dict(
            "sys.modules",
            {"sentence_transformers": mocker.Mock(CrossEncoder=mock_cross_encoder)},
        )
        mocker.patch("asyncio.to_thread", return_value=mock_model_instance)

        # First call loads the model; second call must hit the cache.
        model1 = await _get_cross_encoder("test-model")
        model2 = await _get_cross_encoder("test-model")

        assert model1 == model2 == mock_model_instance

    @pytest.mark.asyncio
    async def test_handles_import_error(self, mocker: MockerFixture) -> None:
        """A missing sentence_transformers package degrades gracefully to None."""
        # pylint: disable=import-outside-toplevel
        from utils.reranker import _cross_encoder_models

        _cross_encoder_models.clear()

        mock_config = mocker.Mock()
        mock_config.reranker.enabled = True
        mocker.patch("utils.vector_search.configuration", mock_config)
        mocker.patch("utils.reranker.configuration", mock_config)

        # Fixed: raise a real ImportError so this test is distinct from
        # test_handles_model_loading_error (the two were previously identical
        # and neither actually simulated an import failure).
        mocker.patch(
            "asyncio.to_thread",
            side_effect=ImportError("No module named 'sentence_transformers'"),
        )

        model = await _get_cross_encoder("test-model")

        assert model is None

    @pytest.mark.asyncio
    async def test_handles_model_loading_error(self, mocker: MockerFixture) -> None:
        """A runtime failure during model instantiation yields None."""
        # pylint: disable=import-outside-toplevel
        from utils.reranker import _cross_encoder_models

        _cross_encoder_models.clear()

        mock_config = mocker.Mock()
        mock_config.reranker.enabled = True
        mocker.patch("utils.vector_search.configuration", mock_config)
        mocker.patch("utils.reranker.configuration", mock_config)

        mocker.patch("asyncio.to_thread", side_effect=Exception("Model loading failed"))

        model = await _get_cross_encoder("test-model")

        assert model is None

    @pytest.mark.asyncio
    async def test_returns_none_when_reranker_disabled(
        self, mocker: MockerFixture
    ) -> None:
        """_get_cross_encoder returns None when the reranker is disabled."""
        # pylint: disable=import-outside-toplevel
        from utils.reranker import _cross_encoder_models

        _cross_encoder_models.clear()

        mock_config = mocker.Mock()
        mock_config.reranker.enabled = False
        mocker.patch("utils.vector_search.configuration", mock_config)

        # CrossEncoder must never be instantiated on the disabled path.
        mock_cross_encoder = mocker.Mock()
        mocker.patch.dict(
            "sys.modules",
            {"sentence_transformers": mocker.Mock(CrossEncoder=mock_cross_encoder)},
        )

        model = await _get_cross_encoder("test-model")

        assert model is None
        mock_cross_encoder.assert_not_called()

    @pytest.mark.asyncio
    async def test_does_not_cache_when_reranker_disabled(
        self, mocker: MockerFixture
    ) -> None:
        """No cache entry is created when the reranker is disabled."""
        # pylint: disable=import-outside-toplevel
        from utils.reranker import _cross_encoder_models

        _cross_encoder_models.clear()

        mock_config = mocker.Mock()
        mock_config.reranker.enabled = False
        mocker.patch("utils.vector_search.configuration", mock_config)

        # Repeated calls must all short-circuit to None without caching.
        model1 = await _get_cross_encoder("test-model")
        model2 = await _get_cross_encoder("test-model")

        assert model1 is None
        assert model2 is None
        assert "test-model" not in _cross_encoder_models
class TestRerankChunksWithCrossEncoder:
    """Tests for the rerank_chunks_with_cross_encoder function."""

    @pytest.mark.asyncio
    async def test_empty_chunks(self) -> None:
        """Reranking an empty chunk list returns an empty list."""
        result = await rerank_chunks_with_cross_encoder("test query", [], 5)
        assert result == []

    @pytest.mark.asyncio
    async def test_successful_reranking(self, mocker: MockerFixture) -> None:
        """Combined cross-encoder and original scores drive the final order."""
        chunks = [
            RAGChunk(content="Content 1", source="source_1", score=0.5),
            RAGChunk(content="Content 2", source="source_2", score=0.3),
            RAGChunk(content="Content 3", source="source_3", score=0.8),
        ]

        mock_model = mocker.Mock()
        mock_model.predict.return_value = [2.5, 1.0, 3.0]  # Raw CE scores
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=mock_model,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 3)

        # The model scores (query, content) pairs in the original chunk order.
        expected_pairs = [
            ("test query", "Content 1"),
            ("test query", "Content 2"),
            ("test query", "Content 3"),
        ]
        mock_model.predict.assert_called_once_with(expected_pairs)

        # Results are sorted by combined score, highest first.
        assert len(result) == 3
        assert result[0].content == "Content 3"
        assert result[1].content == "Content 1"
        assert result[2].content == "Content 2"

        # Combined = 0.3 * normalized CE score + 0.7 * normalized original:
        #   Content 3: 0.3 * 1.0  + 0.7 * 1.0 = 1.0
        #   Content 1: 0.3 * 0.75 + 0.7 * 0.4 = 0.505
        #   Content 2: 0.3 * 0.0  + 0.7 * 0.0 = 0.0
        # Fixed: use pytest.approx everywhere instead of exact float equality,
        # which is brittle for values produced by float arithmetic.
        assert result[0].score == pytest.approx(1.0)
        assert result[1].score == pytest.approx(0.505, abs=0.01)
        assert result[2].score == pytest.approx(0.0)

    @pytest.mark.asyncio
    async def test_top_k_limiting(self, mocker: MockerFixture) -> None:
        """top_k limits the number of returned chunks."""
        chunks = [
            RAGChunk(content="Content 1", source="source_1", score=0.5),
            RAGChunk(content="Content 2", source="source_2", score=0.3),
            RAGChunk(content="Content 3", source="source_3", score=0.8),
        ]

        mock_model = mocker.Mock()
        mock_model.predict.return_value = [2.5, 1.0, 3.0]
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=mock_model,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 2)

        # Only the two best-combined chunks survive.
        assert len(result) == 2
        assert result[0].content == "Content 3"
        assert result[1].content == "Content 1"

    @pytest.mark.asyncio
    async def test_identical_scores_normalization(self, mocker: MockerFixture) -> None:
        """When all CE scores tie, the original scores decide the order."""
        chunks = [
            RAGChunk(content="Content 1", source="source_1", score=0.5),
            RAGChunk(content="Content 2", source="source_2", score=0.3),
        ]

        mock_model = mocker.Mock()
        mock_model.predict.return_value = [1.5, 1.5]  # Identical CE scores
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=mock_model,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 2)

        # Identical CE scores normalize to 0.5 each, so the 70% original-score
        # component dominates:
        #   Content 1: 0.3 * 0.5 + 0.7 * 1.0 = 0.85
        #   Content 2: 0.3 * 0.5 + 0.7 * 0.0 = 0.15
        assert len(result) == 2
        assert result[0].content == "Content 1"
        assert result[1].content == "Content 2"
        # Fixed: approx instead of exact equality on derived floats.
        assert result[0].score == pytest.approx(0.85)
        assert result[1].score == pytest.approx(0.15)

    @pytest.mark.asyncio
    async def test_single_chunk_normalization(self, mocker: MockerFixture) -> None:
        """A single chunk normalizes to the maximum combined score of 1.0."""
        chunks = [RAGChunk(content="Content 1", source="source_1", score=0.5)]

        mock_model = mocker.Mock()
        mock_model.predict.return_value = [2.5]
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=mock_model,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 1)

        assert len(result) == 1
        assert result[0].score == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_model_loading_failure_fallback(self, mocker: MockerFixture) -> None:
        """Original scores are used when the model fails to load."""
        chunks = [
            RAGChunk(content="Content 1", source="source_1", score=0.8),
            RAGChunk(content="Content 2", source="source_2", score=0.6),
        ]

        # _get_cross_encoder returning None signals a load failure.
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=None,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 2)

        # Chunks come back sorted by their untouched original scores.
        assert len(result) == 2
        assert result[0].content == "Content 1"
        assert result[1].content == "Content 2"
        assert result[0].score == 0.8
        assert result[1].score == 0.6

    @pytest.mark.asyncio
    async def test_prediction_failure_fallback(self, mocker: MockerFixture) -> None:
        """Original scores are used when model.predict() raises."""
        chunks = [
            RAGChunk(content="Content 1", source="source_1", score=0.9),
            RAGChunk(content="Content 2", source="source_2", score=0.7),
        ]

        mock_model = mocker.Mock()
        mock_model.predict.side_effect = Exception("Prediction failed")
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=mock_model,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 2)

        assert len(result) == 2
        assert result[0].content == "Content 1"
        assert result[0].score == 0.9

    @pytest.mark.asyncio
    async def test_numpy_array_scores(self, mocker: MockerFixture) -> None:
        """Numpy-array predictions are converted via tolist()."""
        chunks = [RAGChunk(content="Content 1", source="source_1", score=0.5)]

        # Stand-in for a numpy array: only tolist() is exercised.
        mock_scores = mocker.Mock()
        mock_scores.tolist.return_value = [2.5]

        mock_model = mocker.Mock()
        mock_model.predict.return_value = mock_scores
        mocker.patch(
            "utils.reranker._get_cross_encoder",
            new_callable=mocker.AsyncMock,
            return_value=mock_model,
        )

        result = await rerank_chunks_with_cross_encoder("test query", chunks, 1)

        assert len(result) == 1
        assert result[0].score == pytest.approx(1.0)
        mock_scores.tolist.assert_called_once()
class TestApplyByokRerankBoost:
    """Tests for the apply_byok_rerank_boost function."""

    def test_empty_chunks(self) -> None:
        """An empty chunk list yields an empty result."""
        assert not apply_byok_rerank_boost([])

    def test_boost_byok_chunks_only(self) -> None:
        """Only BYOK chunks (non-OKP sources) receive the boost."""
        chunks = [
            RAGChunk(content="BYOK content", source="byok_store", score=0.8),
            RAGChunk(content="OKP content", source=constants.OKP_RAG_ID, score=0.6),
            RAGChunk(content="Another BYOK", source="another_store", score=0.7),
        ]

        boosted = apply_byok_rerank_boost(chunks, boost=2.0)

        assert len(boosted) == 3

        # Index the result by content for order-independent assertions.
        by_content = {chunk.content: chunk for chunk in boosted}

        # BYOK-sourced chunks are scaled by the boost factor...
        assert by_content["BYOK content"].score == 1.6  # 0.8 * 2.0
        assert by_content["Another BYOK"].score == 1.4  # 0.7 * 2.0

        # ...while the OKP chunk keeps its original score.
        assert by_content["OKP content"].score == 0.6

    def test_sorting_by_boosted_scores(self) -> None:
        """Chunks come back ordered by boosted score, highest first."""
        chunks = [
            RAGChunk(content="Low BYOK", source="byok_store", score=0.5),
            RAGChunk(content="High OKP", source=constants.OKP_RAG_ID, score=0.9),
            RAGChunk(content="Mid BYOK", source="another_store", score=0.7),
        ]

        boosted = apply_byok_rerank_boost(chunks, boost=2.0)

        # After boosting: Mid BYOK=1.4, Low BYOK=1.0, High OKP=0.9 (unboosted),
        # so the descending order is Mid BYOK, Low BYOK, High OKP.
        assert [chunk.content for chunk in boosted] == [
            "Mid BYOK",
            "Low BYOK",
            "High OKP",
        ]

    def test_default_boost_factor(self) -> None:
        """Omitting boost applies constants.BYOK_RAG_RERANK_BOOST."""
        chunks = [RAGChunk(content="BYOK content", source="byok_store", score=0.8)]

        boosted = apply_byok_rerank_boost(chunks)

        # Default boost is constants.BYOK_RAG_RERANK_BOOST (1.2).
        assert boosted[0].score == 0.8 * constants.BYOK_RAG_RERANK_BOOST

    def test_none_scores_handled(self) -> None:
        """Chunks carrying a None score do not break the boosting pass."""
        chunks = [
            RAGChunk(content="BYOK with score", source="byok_store", score=0.8),
            RAGChunk(content="BYOK no score", source="byok_store", score=None),
            RAGChunk(content="OKP no score", source=constants.OKP_RAG_ID, score=None),
        ]

        boosted = apply_byok_rerank_boost(chunks, boost=2.0)

        assert len(boosted) == 3

        # NOTE(review): None scores are presumably treated as negative infinity
        # for sorting purposes; here we only verify that the chunk with a real
        # score is boosted correctly alongside them.
        scored_chunk = next(c for c in boosted if c.content == "BYOK with score")
        assert scored_chunk.score == 1.6  # 0.8 * 2.0

    def test_preserves_chunk_attributes(self) -> None:
        """Boosting keeps content, source and attributes intact."""
        original = RAGChunk(
            content="Test content",
            source="byok_store",
            score=0.8,
            attributes={"title": "Test Doc", "url": "http://example.com"},
        )

        boosted = apply_byok_rerank_boost([original], boost=1.5)

        assert len(boosted) == 1
        result = boosted[0]
        assert result.content == "Test content"
        assert result.source == "byok_store"
        assert abs(result.score - 1.2) < 1e-10  # 0.8 * 1.5
        assert result.attributes == {
            "title": "Test Doc",
            "url": "http://example.com",
        }