Source code for langchain.retrievers.document_compressors.chain_filter
"""Filter that uses an LLM to drop documents that aren't relevant to the query."""fromcollections.abcimportSequencefromtypingimportAny,Callable,Optionalfromlangchain_core.callbacksimportCallbacksfromlangchain_core.documentsimportBaseDocumentCompressor,Documentfromlangchain_core.language_modelsimportBaseLanguageModelfromlangchain_core.output_parsersimportStrOutputParserfromlangchain_core.promptsimportBasePromptTemplate,PromptTemplatefromlangchain_core.runnablesimportRunnablefromlangchain_core.runnables.configimportRunnableConfigfrompydanticimportConfigDictfromlangchain.chainsimportLLMChainfromlangchain.output_parsers.booleanimportBooleanOutputParserfromlangchain.retrievers.document_compressors.chain_filter_promptimport(prompt_template,)def_get_default_chain_prompt()->PromptTemplate:returnPromptTemplate(template=prompt_template,input_variables=["question","context"],output_parser=BooleanOutputParser(),)
def default_get_input(query: str, doc: Document) -> dict[str, Any]:
    """Return the compression chain input."""
    return {"question": query, "context": doc.page_content}
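For reference, this helper just maps a query/Document pair onto the prompt's two input variables (the strings below are illustrative):

from langchain_core.documents import Document

doc = Document(page_content="Photosynthesis converts light into chemical energy.")
default_get_input("How do plants make energy?", doc)
# -> {"question": "How do plants make energy?",
#     "context": "Photosynthesis converts light into chemical energy."}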
class LLMChainFilter(BaseDocumentCompressor):
    """Filter that drops documents that aren't relevant to the query."""

    llm_chain: Runnable
    """LLM wrapper to use for filtering documents.
    The chain prompt is expected to have a BooleanOutputParser."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []

        # Batch the chain over all documents and pair each output with its
        # source document.
        config = RunnableConfig(callbacks=callbacks)
        outputs = zip(
            self.llm_chain.batch(
                [self.get_input(query, doc) for doc in documents],
                config=config,
            ),
            documents,
        )
        for output_, doc in outputs:
            include_doc = None
            if isinstance(self.llm_chain, LLMChain):
                # Legacy LLMChain: extract the raw text output and run the
                # prompt's output parser (expected to be a BooleanOutputParser).
                output = output_[self.llm_chain.output_key]
                if self.llm_chain.prompt.output_parser is not None:
                    include_doc = self.llm_chain.prompt.output_parser.parse(output)
            elif isinstance(output_, bool):
                # LCEL chain (prompt | llm | parser): the parser already
                # produced a bool.
                include_doc = output_
            if include_doc:
                filtered_docs.append(doc)
        return filtered_docs
    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []

        # Same logic as compress_documents, but using the chain's async batch.
        config = RunnableConfig(callbacks=callbacks)
        outputs = zip(
            await self.llm_chain.abatch(
                [self.get_input(query, doc) for doc in documents],
                config=config,
            ),
            documents,
        )
        for output_, doc in outputs:
            include_doc = None
            if isinstance(self.llm_chain, LLMChain):
                output = output_[self.llm_chain.output_key]
                if self.llm_chain.prompt.output_parser is not None:
                    include_doc = self.llm_chain.prompt.output_parser.parse(output)
            elif isinstance(output_, bool):
                include_doc = output_
            if include_doc:
                filtered_docs.append(doc)
        return filtered_docs
    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMChainFilter":
        """Create an LLMChainFilter from a language model.

        Args:
            llm: The language model to use for filtering.
            prompt: The prompt to use for the filter.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            An LLMChainFilter that uses the given language model.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        if _prompt.output_parser is not None:
            parser = _prompt.output_parser
        else:
            parser = StrOutputParser()
        llm_chain = _prompt | llm | parser
        return cls(llm_chain=llm_chain, **kwargs)
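A minimal usage sketch, assuming langchain_openai is installed and OPENAI_API_KEY is set; the model name and documents below are illustrative:

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
doc_filter = LLMChainFilter.from_llm(llm)

docs = [
    Document(page_content="The Eiffel Tower is in Paris."),
    Document(page_content="Bananas are rich in potassium."),
]
# Only documents the model judges relevant to the query are returned.
kept = doc_filter.compress_documents(docs, query="Where is the Eiffel Tower?")

Because from_llm builds the chain as prompt | llm | parser, each batch output is already a bool, and compress_documents takes the second branch when deciding whether to keep a document.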