from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pymilvus import AnnSearchRequest, Collection
from pymilvus.client.abstract import BaseRanker, SearchResult # type: ignore
from langchain_milvus.utils.sparse import BaseSparseEmbedding
class MilvusCollectionHybridSearchRetriever(BaseRetriever):
    """Hybrid search retriever that uses a Milvus Collection to retrieve
    documents based on multiple vector fields.

    One ANN search request is built per entry of ``anns_fields`` (the query is
    embedded with the matching entry of ``field_embeddings``), and the
    per-field results are merged by ``rerank`` through
    ``Collection.hybrid_search``.

    For more information, please refer to:
    https://milvus.io/docs/release_notes.md#Multi-Embedding---Hybrid-Search
    """

    collection: Collection
    """Milvus Collection object."""
    rerank: BaseRanker
    """Milvus ranker object. Such as WeightedRanker or RRFRanker."""
    anns_fields: List[str]
    """The names of vector fields that are used for ANNS search."""
    field_embeddings: List[Union[Embeddings, BaseSparseEmbedding]]
    """The embedding functions of each vector field,
    which can be either Embeddings or BaseSparseEmbedding.
    Must have the same length as `anns_fields`."""
    field_search_params: Optional[List[Dict]] = None
    """The search parameters of each vector field.
    If not specified, the default search parameters will be used."""
    field_limits: Optional[List[int]] = None
    """Limit number of results for each ANNS field.
    If not specified, the default top_k will be used."""
    field_exprs: Optional[List[Optional[str]]] = None
    """The boolean expression for filtering the search results of each ANNS field."""
    top_k: int = 4
    """Final top-K number of documents to retrieve."""
    text_field: str = "text"
    """The text field name,
    which will be used as the `page_content` of a `Document` object."""
    output_fields: Optional[List[str]] = None
    """Final output fields of the documents.
    If not specified, all fields except those listed in `anns_fields` will be
    used as output fields. `text_field` is always included; the remaining
    output fields become the `metadata` of a `Document` object."""

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        num_fields = len(self.anns_fields)
        # If some parameters are not specified, set per-field default values.
        if self.field_search_params is None:
            # Build one dict per field (instead of `[d] * n`, which would share a
            # single dict object) so that changing one field's params cannot
            # silently change every other field's params.
            self.field_search_params = [
                {"metric_type": "L2", "params": {"nprobe": 10}}
                for _ in range(num_fields)
            ]
        if self.field_limits is None:
            self.field_limits = [self.top_k] * num_fields
        if self.field_exprs is None:
            self.field_exprs = [None] * num_fields
        # Check the fields
        self._validate_fields_num()
        self.output_fields = self._get_output_fields()
        self._validate_fields_name()
        # Load collection
        self.collection.load()

    def _validate_fields_num(self) -> None:
        """Check that every per-field list lines up with ``anns_fields``.

        Raises:
            AssertionError: if fewer than two ANNS fields are given.
            ValueError: if ``field_embeddings``, ``field_limits``,
                ``field_exprs`` or ``field_search_params`` do not have the same
                length as ``anns_fields``.
        """
        assert (
            len(self.anns_fields) >= 2
        ), "At least two fields are required for hybrid search."
        # `field_embeddings` is zipped with `anns_fields` when building the
        # requests; zip() silently drops the surplus entries of the longer
        # list, so a length mismatch has to be rejected explicitly here.
        lengths = [len(self.anns_fields), len(self.field_embeddings)]
        if self.field_limits is not None:
            lengths.append(len(self.field_limits))
        if self.field_exprs is not None:
            lengths.append(len(self.field_exprs))
        if not all(length == lengths[0] for length in lengths):
            raise ValueError("All field-related lists must have the same length.")
        if len(self.field_search_params) != len(self.anns_fields):  # type: ignore[arg-type]
            raise ValueError(
                "field_search_params must have the same length as anns_fields."
            )

    def _validate_fields_name(self) -> None:
        """Check that every referenced field exists in the collection schema.

        Raises:
            AssertionError: if an ANNS field, ``text_field`` or an output field
                is not a field of ``self.collection``.
        """
        collection_fields = {x.name for x in self.collection.schema.fields}
        for field in self.anns_fields:
            assert (
                field in collection_fields
            ), f"{field} is not a valid field in the collection."
        assert (
            self.text_field in collection_fields
        ), f"{self.text_field} is not a valid field in the collection."
        for field in self.output_fields:  # type: ignore[union-attr]
            assert (
                field in collection_fields
            ), f"{field} is not a valid field in the collection."

    def _get_output_fields(self) -> List[str]:
        """Resolve the list of fields requested from the search.

        If ``output_fields`` was given, a copy of it is used (the caller's list
        is not mutated); otherwise every collection field except the ANNS
        fields is used. In both cases ``text_field`` is appended when missing,
        because ``_parse_document`` pops it to build ``page_content``.

        Returns:
            The field names to pass as ``output_fields`` to ``hybrid_search``.
        """
        if self.output_fields:
            output_fields = list(self.output_fields)
        else:
            output_fields = [
                x.name
                for x in self.collection.schema.fields
                if x.name not in self.anns_fields
            ]
        if self.text_field not in output_fields:
            output_fields.append(self.text_field)
        return output_fields

    def _build_ann_search_requests(self, query: str) -> List[AnnSearchRequest]:
        """Build one ``AnnSearchRequest`` per ANNS field for a single query.

        Each request embeds ``query`` with the field's embedding function and
        carries that field's search params, limit and filter expression.
        """
        search_requests = []
        for ann_field, embedding, param, limit, expr in zip(
            self.anns_fields,
            self.field_embeddings,
            self.field_search_params,  # type: ignore[arg-type]
            self.field_limits,  # type: ignore[arg-type]
            self.field_exprs,  # type: ignore[arg-type]
        ):
            request = AnnSearchRequest(
                data=[embedding.embed_query(query)],
                anns_field=ann_field,
                param=param,
                limit=limit,
                expr=expr,
            )
            search_requests.append(request)
        return search_requests

    def _parse_document(self, data: dict) -> Document:
        """Turn one hit's field dict into a Document.

        ``text_field`` is removed from ``data`` and used as ``page_content``;
        the remaining entries become ``metadata``. Mutates ``data``.
        """
        return Document(
            page_content=data.pop(self.text_field),
            metadata=data,
        )

    def _process_search_result(
        self, search_results: List[SearchResult]
    ) -> List[Document]:
        """Convert the hits of the first (and only) query into Documents.

        Only one query vector is sent per request, so only
        ``search_results[0]`` is read; an empty result yields ``[]``.
        """
        if not search_results:
            return []
        documents = []
        for result in search_results[0]:
            data = {x: result.entity.get(x) for x in self.output_fields}  # type: ignore[union-attr]
            documents.append(self._parse_document(data))
        return documents

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run a hybrid search for ``query`` and return up to ``top_k`` Documents.

        Extra ``kwargs`` are accepted for interface compatibility but ignored.
        """
        requests = self._build_ann_search_requests(query)
        search_result = self.collection.hybrid_search(
            requests, self.rerank, limit=self.top_k, output_fields=self.output_fields
        )
        return self._process_search_result(search_result)