Source code for langchain.retrievers.document_compressors.listwise_rerank
"""Filter that uses an LLM to rerank documents listwise and select top-k."""fromtypingimportAny,Dict,List,Optional,Sequencefromlangchain_core.callbacksimportCallbacksfromlangchain_core.documentsimportBaseDocumentCompressor,Documentfromlangchain_core.language_modelsimportBaseLanguageModelfromlangchain_core.promptsimportBasePromptTemplate,ChatPromptTemplatefromlangchain_core.runnablesimportRunnable,RunnableLambda,RunnablePassthroughfrompydanticimportBaseModel,ConfigDict,Field_default_system_tmpl="""{context}Sort the Documents by their relevance to the Query."""_DEFAULT_PROMPT=ChatPromptTemplate.from_messages([("system",_default_system_tmpl),("human","{query}")],)def_get_prompt_input(input_:dict)->Dict[str,Any]:"""Return the compression chain input."""documents=input_["documents"]context=""forindex,docinenumerate(documents):context+=f"Document ID: {index}\n```{doc.page_content}```\n\n"context+=f"Documents = [Document ID: 0, ..., Document ID: {len(documents)-1}]"return{"query":input_["query"],"context":context}def_parse_ranking(results:dict)->List[Document]:ranking=results["ranking"]docs=results["documents"]return[docs[i]foriinranking.ranked_document_ids]

class LLMListwiseRerank(BaseDocumentCompressor):
    """Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Adapted from: https://arxiv.org/pdf/2305.02156.pdf

    ``LLMListwiseRerank`` uses a language model to rerank a list of documents
    based on their relevance to a query.

    **NOTE**: requires that underlying model implement ``with_structured_output``.

    Example usage:
        .. code-block:: python

            from langchain.retrievers.document_compressors.listwise_rerank import (
                LLMListwiseRerank,
            )
            from langchain_core.documents import Document
            from langchain_openai import ChatOpenAI

            documents = [
                Document("Sally is my friend from school"),
                Document("Steve is my friend from home"),
                Document("I didn't always like yogurt"),
                Document("I wonder why it's called football"),
                Document("Where's waldo"),
            ]
            reranker = LLMListwiseRerank.from_llm(
                llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
            )
            compressed_docs = reranker.compress_documents(documents, "Who is steve")

            assert len(compressed_docs) == 3
            assert "Steve" in compressed_docs[0].page_content
    """

    reranker: Runnable[Dict, List[Document]]
    """LLM-based reranker to use for filtering documents. Expected to take in a dict
        with 'documents: Sequence[Document]' and 'query: str' keys and output a
        List[Document]."""

    top_n: int = 3
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        results = self.reranker.invoke(
            {"documents": documents, "query": query}, config={"callbacks": callbacks}
        )
        return results[: self.top_n]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMListwiseRerank":
        """Create a LLMListwiseRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering. **Must implement
                BaseLanguageModel.with_structured_output().**
            prompt: The prompt to use for the filter.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMListwiseRerank document compressor that uses the given language model.
        """
        if llm.with_structured_output == BaseLanguageModel.with_structured_output:
            raise ValueError(
                f"llm of type {type(llm)} does not implement `with_structured_output`."
            )

        class RankDocuments(BaseModel):
            """Rank the documents by their relevance to the user question.
            Rank from most to least relevant."""

            ranked_document_ids: List[int] = Field(
                ...,
                description=(
                    "The integer IDs of the documents, sorted from most to least "
                    "relevant to the user question."
                ),
            )

        _prompt = prompt if prompt is not None else _DEFAULT_PROMPT
        reranker = RunnablePassthrough.assign(
            ranking=RunnableLambda(_get_prompt_input)
            | _prompt
            | llm.with_structured_output(RankDocuments)
        ) | RunnableLambda(_parse_ranking)
        return cls(reranker=reranker, **kwargs)
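
# Usage sketch (illustrative, not part of the module above): `from_llm` also
# accepts a custom prompt. Based on `_get_prompt_input`, the prompt should use
# the `context` and `query` input variables. `ChatOpenAI` and the model name
# below are assumptions; any chat model that implements
# `with_structured_output` should work, and running this requires the
# `langchain_openai` package plus API credentials.

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

custom_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "{context}\n\nRank the documents by how well they answer the Query.",
        ),
        ("human", "{query}"),
    ]
)
reranker = LLMListwiseRerank.from_llm(
    llm=ChatOpenAI(model="gpt-4o-mini"),  # assumed model name; swap in any structured-output chat model
    top_n=2,
    prompt=custom_prompt,
)
docs = [
    Document("Steve is my friend from home"),
    Document("I didn't always like yogurt"),
]
print(reranker.compress_documents(docs, "Who is Steve?"))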