Source code for langchain.retrievers.multi_vector

from enum import Enum
from typing import Dict, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore

from langchain.storage._lc_store import create_kv_docstore


class SearchType(str, Enum):
    """Enumerator of the types of search to perform."""

    similarity = "similarity"
    """Similarity search."""
    similarity_score_threshold = "similarity_score_threshold"
    """Similarity search with a score threshold."""
    mmr = "mmr"
    """Maximal Marginal Relevance reranking of similarity search."""

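Because SearchType mixes in str, plain string values round-trip cleanly through the enum, which is convenient when the search type comes from configuration. A quick standalone sketch (hypothetical snippet, not part of this module):

    from langchain.retrievers.multi_vector import SearchType

    # String values map onto enum members and compare equal to them.
    assert SearchType("mmr") is SearchType.mmr
    assert SearchType.similarity == "similarity"
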
class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document."""

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    byte_store: Optional[ByteStore] = None
    """The lower-level backing storage layer for the parent documents"""
    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""
    id_key: str = "doc_id"
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / mmr)"""

    @root_validator(pre=True)
    def shim_docstore(cls, values: Dict) -> Dict:
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            raise Exception("You must pass a `byte_store` parameter.")
        values["docstore"] = docstore
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        # We do this to maintain the order of the ids that are returned
        ids = []
        for d in sub_docs:
            if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        docs = self.docstore.mget(ids)
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )

        # We do this to maintain the order of the ids that are returned
        ids = []
        for d in sub_docs:
            if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        docs = await self.docstore.amget(ids)
        return [d for d in docs if d is not None]
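
The retriever ties a vectorstore of small chunks to a docstore of full parent documents through the id_key metadata field: similarity (or MMR) search runs over the chunks, and the matching parents are what get returned. A minimal usage sketch follows; it assumes Chroma, OpenAIEmbeddings, and RecursiveCharacterTextSplitter are available as the backing vectorstore, embedding model, and splitter (any VectorStore/ByteStore pair works the same way), and the document texts are placeholders:

    import uuid

    from langchain.retrievers.multi_vector import MultiVectorRetriever
    from langchain.storage import InMemoryByteStore
    from langchain_community.vectorstores import Chroma
    from langchain_core.documents import Document
    from langchain_openai import OpenAIEmbeddings
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Parent documents: these are what the retriever ultimately returns.
    parent_docs = [
        Document(page_content="A long document about topic A ..."),
        Document(page_content="A long document about topic B ..."),
    ]
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

    # Small chunks: these are what actually gets embedded and searched.
    # Each chunk carries its parent's id under the retriever's `id_key`.
    splitter = RecursiveCharacterTextSplitter(chunk_size=400)
    sub_docs = []
    for doc_id, doc in zip(doc_ids, parent_docs):
        chunks = splitter.split_documents([doc])
        for chunk in chunks:
            chunk.metadata["doc_id"] = doc_id
        sub_docs.extend(chunks)

    vectorstore = Chroma(
        collection_name="multi_vector_demo",
        embedding_function=OpenAIEmbeddings(),
    )
    byte_store = InMemoryByteStore()

    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        byte_store=byte_store,
        id_key="doc_id",
    )

    # Index the small chunks and store the parents under their ids.
    retriever.vectorstore.add_documents(sub_docs)
    retriever.docstore.mset(list(zip(doc_ids, parent_docs)))

    # Search runs over the chunks; the parent documents come back.
    results = retriever.invoke("topic A")

To rerank with MMR or filter by relevance score instead, construct the retriever with search_type=SearchType.mmr or search_type=SearchType.similarity_score_threshold and route any extra arguments (for example k or score_threshold) through search_kwargs.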