# Source code for langchain.retrievers.document_compressors.listwise_rerank

"""Filter that uses an LLM to rerank documents listwise and select top-k."""

from typing import Any, Dict, List, Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pydantic import BaseModel, ConfigDict, Field

# Default system-message template: the rendered, ID-tagged document listing is
# injected as ``{context}`` (see _get_prompt_input), followed by the ranking
# instruction. The raw user query is sent as the human turn.
_default_system_tmpl = """{context}

Sort the Documents by their relevance to the Query."""
# Default two-message chat prompt used by LLMListwiseRerank.from_llm when the
# caller does not supply a custom prompt.
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", _default_system_tmpl), ("human", "{query}")],
)


def _get_prompt_input(input_: dict) -> Dict[str, Any]:
    """Return the compression chain input."""
    documents = input_["documents"]
    context = ""
    for index, doc in enumerate(documents):
        context += f"Document ID: {index}\n```{doc.page_content}```\n\n"
    context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
    return {"query": input_["query"], "context": context}


def _parse_ranking(results: dict) -> List[Document]:
    """Reorder the original documents according to the structured LLM ranking.

    ``results`` carries the pass-through ``documents`` plus a ``ranking``
    object whose ``ranked_document_ids`` lists document indices from most to
    least relevant.
    """
    ranking = results["ranking"]
    documents = results["documents"]
    ordered: List[Document] = []
    for doc_id in ranking.ranked_document_ids:
        ordered.append(documents[doc_id])
    return ordered


class LLMListwiseRerank(BaseDocumentCompressor):
    """Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Adapted from: https://arxiv.org/pdf/2305.02156.pdf

    ``LLMListwiseRerank`` uses a language model to rerank a list of documents
    based on their relevance to a query.

    **NOTE**: requires that underlying model implement ``with_structured_output``.

    Example usage:
        .. code-block:: python

            from langchain.retrievers.document_compressors.listwise_rerank import (
                LLMListwiseRerank,
            )
            from langchain_core.documents import Document
            from langchain_openai import ChatOpenAI

            documents = [
                Document("Sally is my friend from school"),
                Document("Steve is my friend from home"),
                Document("I didn't always like yogurt"),
                Document("I wonder why it's called football"),
                Document("Where's waldo"),
            ]

            reranker = LLMListwiseRerank.from_llm(
                llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
            )
            compressed_docs = reranker.compress_documents(documents, "Who is steve")
            assert len(compressed_docs) == 3
            assert "Steve" in compressed_docs[0].page_content
    """

    reranker: Runnable[Dict, List[Document]]
    """LLM-based reranker to use for filtering documents. Expected to take in a dict
    with 'documents: Sequence[Document]' and 'query: str' keys and output a
    List[Document]."""

    top_n: int = 3
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query.

        Args:
            documents: The documents to rerank.
            query: The query to rank against.
            callbacks: Optional callbacks forwarded to the reranker run.

        Returns:
            At most ``top_n`` documents, most relevant first.
        """
        results = self.reranker.invoke(
            {"documents": documents, "query": query}, config={"callbacks": callbacks}
        )
        return results[: self.top_n]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMListwiseRerank":
        """Create a LLMListwiseRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering. **Must implement
                BaseLanguageModel.with_structured_output().**
            prompt: The prompt to use for the filter.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMListwiseRerank document compressor that uses the given language
            model.

        Raises:
            ValueError: If ``llm`` does not override ``with_structured_output``.
        """
        # BUGFIX: the original guard compared the *bound* method
        # ``llm.with_structured_output`` against the plain function
        # ``BaseLanguageModel.with_structured_output``; a bound method never
        # compares equal to a function, so the guard could never fire. Looking
        # the attribute up on the class and using ``is`` correctly detects
        # models that do not override the base implementation.
        if (
            type(llm).with_structured_output
            is BaseLanguageModel.with_structured_output
        ):
            raise ValueError(
                f"llm of type {type(llm)} does not implement `with_structured_output`."
            )

        class RankDocuments(BaseModel):
            """Rank the documents by their relevance to the user question.
            Rank from most to least relevant."""

            ranked_document_ids: List[int] = Field(
                ...,
                description=(
                    "The integer IDs of the documents, sorted from most to least "
                    "relevant to the user question."
                ),
            )

        _prompt = prompt if prompt is not None else _DEFAULT_PROMPT
        # Chain: keep the original input dict flowing through, add a ``ranking``
        # key produced by prompt -> LLM (structured output), then reorder the
        # documents by the returned IDs.
        reranker = RunnablePassthrough.assign(
            ranking=RunnableLambda(_get_prompt_input)
            | _prompt
            | llm.with_structured_output(RankDocuments)
        ) | RunnableLambda(_parse_ranking)
        return cls(reranker=reranker, **kwargs)