Source code for langchain.retrievers.document_compressors.embeddings_filter
from collections.abc import Sequence
from typing import Callable, Optional

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from pydantic import ConfigDict, Field


def _get_similarity_function() -> Callable:
    """Return the default similarity function (cosine similarity).

    The function is imported lazily from ``langchain_community`` so that this
    module can be imported without that optional package installed.

    Returns:
        ``langchain_community.utils.math.cosine_similarity``.

    Raises:
        ImportError: If ``langchain-community`` is not installed. The original
            import failure is attached as the cause.
    """
    try:
        from langchain_community.utils.math import cosine_similarity
    except ImportError as e:
        # Chain the original error so the real reason for the failed import
        # (e.g. a broken transitive dependency) is not hidden.
        raise ImportError(
            "To use please install langchain-community "
            "with `pip install langchain-community`."
        ) from e
    return cosine_similarity
class EmbeddingsFilter(BaseDocumentCompressor):
    """Document compressor that uses embeddings to drop documents
    unrelated to the query.

    Documents and the query are embedded with ``embeddings`` and scored with
    ``similarity_fn``. The ``k`` best-scoring documents are kept (if ``k`` is
    set), and of those only the ones whose score is strictly greater than
    ``similarity_threshold`` (if set).
    """

    embeddings: Embeddings
    """Embeddings to use for embedding document contents and queries."""
    similarity_fn: Callable = Field(default_factory=_get_similarity_function)
    """Similarity function for comparing documents. Function expected to take as input
    two matrices (List[List[float]]) and return a matrix of scores where higher values
    indicate greater similarity."""
    k: Optional[int] = 20
    """The maximum number of most-similar documents to return. Can be set to None, in
    which case `similarity_threshold` must be specified. Defaults to 20."""
    similarity_threshold: Optional[float] = None
    """Minimum similarity to the query that a document must exceed in order to be
    kept. Defaults to None, must be specified if `k` is set to None."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @pre_init
    def validate_params(cls, values: dict) -> dict:
        """Validate similarity parameters.

        Raises:
            ValueError: If both `k` and `similarity_threshold` are None, since
                no filtering criterion would then be defined.
        """
        if values["k"] is None and values["similarity_threshold"] is None:
            raise ValueError("Must specify one of `k` or `similarity_threshold`.")
        return values

    def _select_documents(
        self,
        stateful_documents: Sequence,
        embedded_documents: Sequence,
        embedded_query: Sequence,
    ) -> Sequence[Document]:
        """Score documents against the query and keep the ones that qualify.

        Shared by the sync and async entry points so both apply exactly the
        same ranking and threshold logic.

        Args:
            stateful_documents: Documents wrapped so that they carry a mutable
                ``state`` dict; aligned index-for-index with
                ``embedded_documents``.
            embedded_documents: One embedding per document.
            embedded_query: Embedding of the query.

        Returns:
            The selected stateful documents, each annotated with
            ``state["query_similarity_score"]``. When ``k`` is set they are
            ordered from most to least similar.
        """
        # The public entry points verify that numpy is importable before any
        # embedding work is done, so a plain import is sufficient here.
        import numpy as np

        # Coerce the score row to an ndarray: a custom ``similarity_fn`` may
        # return nested lists, which do not support the index-array lookups
        # used below.
        similarity = np.asarray(
            self.similarity_fn([embedded_query], embedded_documents)[0]
        )
        included_idxs: np.ndarray = np.arange(len(embedded_documents))
        if self.k is not None:
            # argsort is ascending; reverse it to take the k highest scores.
            included_idxs = np.argsort(similarity)[::-1][: self.k]
        if self.similarity_threshold is not None:
            similar_enough = np.where(
                similarity[included_idxs] > self.similarity_threshold
            )
            included_idxs = included_idxs[similar_enough]
        for i in included_idxs:
            stateful_documents[i].state["query_similarity_score"] = similarity[i]
        return [stateful_documents[i] for i in included_idxs]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter documents based on similarity of their embeddings to the query.

        Args:
            documents: Candidate documents.
            query: Query the documents are compared against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            The documents that pass the `k` / `similarity_threshold` filters.
            An empty input yields an empty list without any embedding calls.

        Raises:
            ImportError: If ``langchain-community`` or ``numpy`` is missing.
        """
        try:
            from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E501
                _get_embeddings_from_stateful_docs,
                get_stateful_documents,
            )
        except ImportError as e:
            raise ImportError(
                "To use please install langchain-community "
                "with `pip install langchain-community`."
            ) from e

        try:
            import numpy  # noqa: F401  (fail fast, before any embedding call)
        except ImportError as e:
            raise ImportError(
                "Could not import numpy, please install with `pip install numpy`."
            ) from e

        # Nothing to score: skip the embedding calls and avoid indexing into
        # an empty similarity result.
        if not documents:
            return []

        stateful_documents = get_stateful_documents(documents)
        embedded_documents = _get_embeddings_from_stateful_docs(
            self.embeddings, stateful_documents
        )
        embedded_query = self.embeddings.embed_query(query)
        return self._select_documents(
            stateful_documents, embedded_documents, embedded_query
        )

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter documents based on similarity of their embeddings to the query.

        Async counterpart of `compress_documents`: the embeddings are obtained
        through the awaitable APIs, the selection logic is identical.

        Args:
            documents: Candidate documents.
            query: Query the documents are compared against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            The documents that pass the `k` / `similarity_threshold` filters.
            An empty input yields an empty list without any embedding calls.

        Raises:
            ImportError: If ``langchain-community`` or ``numpy`` is missing.
        """
        try:
            from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E501
                _aget_embeddings_from_stateful_docs,
                get_stateful_documents,
            )
        except ImportError as e:
            raise ImportError(
                "To use please install langchain-community "
                "with `pip install langchain-community`."
            ) from e

        try:
            import numpy  # noqa: F401  (fail fast, before any embedding call)
        except ImportError as e:
            raise ImportError(
                "Could not import numpy, please install with `pip install numpy`."
            ) from e

        # Nothing to score: skip the embedding calls and avoid indexing into
        # an empty similarity result.
        if not documents:
            return []

        stateful_documents = get_stateful_documents(documents)
        embedded_documents = await _aget_embeddings_from_stateful_docs(
            self.embeddings, stateful_documents
        )
        embedded_query = await self.embeddings.aembed_query(query)
        return self._select_documents(
            stateful_documents, embedded_documents, embedded_query
        )