# Source code for langchain_mongodb.retrievers.hybrid_search
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.pipelines import (
combine_pipelines,
final_hybrid_stage,
reciprocal_rank_stage,
text_search_stage,
vector_search_stage,
)
from langchain_mongodb.utils import make_serializable
# [docs]
class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
    """Hybrid Search Retriever combines vector and full-text searches,
    weighting them via the Reciprocal Rank Fusion (RRF) algorithm.

    Increasing the vector_penalty will reduce the importance on the vector search.
    Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
    For more on the algorithm, see
    https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
    """

    vectorstore: MongoDBAtlasVectorSearch
    """MongoDBAtlas VectorStore"""
    search_index_name: str
    """Atlas Search Index (full-text) name"""
    top_k: int = 4
    """Number of documents to return."""
    oversampling_factor: int = 10
    """This times top_k is the number of candidates chosen at each step"""
    pre_filter: Optional[Dict[str, Any]] = None
    """(Optional) Any MQL match expression comparing an indexed field"""
    post_filter: Optional[List[Dict[str, Any]]] = None
    """(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
    vector_penalty: float = 60.0
    """Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
    fulltext_penalty: float = 60.0
    """Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
    # Fix: annotation was `float` although the default and all uses are boolean.
    show_embeddings: bool = False
    """If true, returned Document metadata will include vectors."""

    @property
    def collection(self) -> Collection:
        """The MongoDB collection backing the vector store."""
        return self.vectorstore._collection

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents that are highest scoring / most similar to query.

        Note that the same query is used in both searches,
        embedded for vector search, and as-is for full-text search.

        Args:
            query: String to find relevant documents for
            run_manager: The callback handler to use
        Returns:
            List of relevant documents
        """
        query_vector = self.vectorstore._embedding.embed_query(query)
        scores_fields = ["vector_score", "fulltext_score"]
        pipeline: List[Any] = []

        # First we build up the aggregation pipeline,
        # then it is passed to the server to execute.

        # Vector Search stage: ANN search, then per-rank RRF scoring.
        vector_pipeline = [
            vector_search_stage(
                query_vector=query_vector,
                search_field=self.vectorstore._embedding_key,
                index_name=self.vectorstore._index_name,
                top_k=self.top_k,
                filter=self.pre_filter,
                oversampling_factor=self.oversampling_factor,
            )
        ]
        vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)
        combine_pipelines(pipeline, vector_pipeline, self.collection.name)

        # Full-Text Search stage, with its own RRF penalty.
        text_pipeline = text_search_stage(
            query=query,
            search_field=self.vectorstore._text_key,
            index_name=self.search_index_name,
            limit=self.top_k,
            filter=self.pre_filter,
        )
        text_pipeline.extend(
            reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
        )
        combine_pipelines(pipeline, text_pipeline, self.collection.name)

        # Sum the two RRF scores and sort, keeping the top_k results.
        pipeline.extend(
            final_hybrid_stage(scores_fields=scores_fields, limit=self.top_k)
        )

        # Removal of embeddings unless requested.
        if not self.show_embeddings:
            pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})

        # Post filtering
        if self.post_filter is not None:
            pipeline.extend(self.post_filter)

        # Execution
        cursor = self.collection.aggregate(pipeline)  # type: ignore[arg-type]

        # Formatting. NOTE: the combined "score" (and component score fields)
        # are left in `res` and therefore surface in Document.metadata.
        docs = []
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs