"""Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search
See the following for more:
- `Full-Text Search <https://www.mongodb.com/docs/atlas/atlas-search/aggregation-stages/search/#mongodb-pipeline-pipe.-search>`_
- `MongoDB Operators <https://www.mongodb.com/docs/atlas/atlas-search/operators-and-collectors/#std-label-operators-ref>`_
- `Vector Search <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/>`_
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
"""
from typing import Any, Dict, List, Optional
[docs]def text_search_stage(
query: str,
search_field: str,
index_name: str,
limit: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
include_scores: Optional[bool] = True,
**kwargs: Any,
) -> List[Dict[str, Any]]: # noqa: E501
"""Full-Text search using Lucene's standard (BM25) analyzer
Args:
query: Input text to search for
search_field: Field in Collection that will be searched
index_name: Atlas Search Index name
limit: Maximum number of documents to return. Default of no limit
filter: Any MQL match expression comparing an indexed field
include_scores: Scores provide measure of relative relevance
Returns:
Dictionary defining the $search stage
"""
pipeline = [
{
"$search": {
"index": index_name,
"text": {"query": query, "path": search_field},
}
}
]
if filter:
pipeline.append({"$match": filter}) # type: ignore
if include_scores:
pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
if limit:
pipeline.append({"$limit": limit}) # type: ignore
return pipeline # type: ignore
[docs]def vector_search_stage(
query_vector: List[float],
search_field: str,
index_name: str,
top_k: int = 4,
filter: Optional[Dict[str, Any]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> Dict[str, Any]: # noqa: E501
"""Vector Search Stage without Scores.
Scoring is applied later depending on strategy.
vector search includes a vectorSearchScore that is typically used.
hybrid uses Reciprocal Rank Fusion.
Args:
query_vector: List of embedding vector
search_field: Field in Collection containing embedding vectors
index_name: Name of Atlas Vector Search Index tied to Collection
top_k: Number of documents to return
oversampling_factor: this times limit is the number of candidates
filter: MQL match expression comparing an indexed field.
Some operators are not supported.
See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
Returns:
Dictionary defining the $vectorSearch
"""
stage = {
"index": index_name,
"path": search_field,
"queryVector": query_vector,
"numCandidates": top_k * oversampling_factor,
"limit": top_k,
}
if filter:
stage["filter"] = filter
return {"$vectorSearch": stage}
[docs]def combine_pipelines(
pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
) -> None:
"""Combines two aggregations into a single result set in-place."""
if pipeline:
pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
else:
pipeline.extend(stage)
[docs]def reciprocal_rank_stage(
score_field: str, penalty: float = 0, **kwargs: Any
) -> List[Dict[str, Any]]:
"""Stage adds Reciprocal Rank Fusion weighting.
First, it pushes documents retrieved from previous stage
into a temporary sub-document. It then unwinds to establish
the rank to each and applies the penalty.
Args:
score_field: A unique string to identify the search being ranked
penalty: A non-negative float.
extra_fields: Any fields other than text_field that one wishes to keep.
Returns:
RRF score := \frac{1}{rank + penalty} with rank in [1,2,..,n]
"""
rrf_pipeline = [
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
{
"$addFields": {
f"docs.{score_field}": {
"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
},
"docs.rank": "$rank",
"_id": "$docs._id",
}
},
{"$replaceRoot": {"newRoot": "$docs"}},
]
return rrf_pipeline # type: ignore
[docs]def final_hybrid_stage(
scores_fields: List[str], limit: int, **kwargs: Any
) -> List[Dict[str, Any]]:
"""Sum weighted scores, sort, and apply limit.
Args:
scores_fields: List of fields given to scores of vector and text searches
limit: Number of documents to return
Returns:
Final aggregation stages
"""
return [
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
{"$replaceRoot": {"newRoot": "$docs"}},
{"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
{"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
{"$sort": {"score": -1}},
{"$limit": limit},
]