Source code for langchain.retrievers.document_compressors.chain_extract

"""DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""

from __future__ import annotations

from typing import Any, Callable, Dict, Optional, Sequence, cast

from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from pydantic import ConfigDict

from langchain.chains.llm import LLMChain
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.retrievers.document_compressors.chain_extract_prompt import (
    prompt_template,
)


def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
    """Return the compression chain input."""
    return {"question": query, "context": doc.page_content}
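
# Illustrative example (not part of the original module): for
# query = "Who wrote Hamlet?", this returns
# {"question": "Who wrote Hamlet?", "context": doc.page_content},
# matching the {question}/{context} input variables of the default prompt.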


class NoOutputParser(BaseOutputParser[str]):
    """Parse outputs that may be the sentinel "no output" string."""

    no_output_str: str = "NO_OUTPUT"

    def parse(self, text: str) -> str:
        cleaned_text = text.strip()
        if cleaned_text == self.no_output_str:
            return ""
        return cleaned_text
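
# Behavior sketch (illustrative, not part of the original module):
#
#     NoOutputParser().parse("NO_OUTPUT")       # -> ""  (nothing relevant)
#     NoOutputParser().parse("  kept text \n")  # -> "kept text"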


def _get_default_chain_prompt() -> PromptTemplate:
    # Fill the sentinel string into the shared extraction prompt template and
    # attach the parser that recognizes it.
    output_parser = NoOutputParser()
    template = prompt_template.format(no_output_str=output_parser.no_output_str)
    return PromptTemplate(
        template=template,
        input_variables=["question", "context"],
        output_parser=output_parser,
    )


class LLMChainExtractor(BaseDocumentCompressor):
    """Document compressor that uses an LLM chain to extract
    the relevant parts of documents."""

    llm_chain: Runnable
    """LLM wrapper to use for compressing documents."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress page content of raw documents."""
        compressed_docs = []
        for doc in documents:
            _input = self.get_input(query, doc)
            output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
            if isinstance(self.llm_chain, LLMChain):
                # Legacy LLMChain path: pull the string out of the result dict
                # and apply the prompt's output parser if one is attached.
                output = output_[self.llm_chain.output_key]
                if self.llm_chain.prompt.output_parser is not None:
                    output = self.llm_chain.prompt.output_parser.parse(output)
            else:
                # LCEL Runnable path: the chain already returns a parsed string.
                output = output_
            if len(output) == 0:
                # An empty extraction (e.g. NO_OUTPUT) drops the document.
                continue
            compressed_docs.append(
                Document(page_content=cast(str, output), metadata=doc.metadata)
            )
        return compressed_docs

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress page content of raw documents asynchronously."""
        inputs = [self.get_input(query, doc) for doc in documents]
        # Run all extractions concurrently in a single batch call.
        outputs = await self.llm_chain.abatch(inputs, {"callbacks": callbacks})
        compressed_docs = []
        for i, doc in enumerate(documents):
            if len(outputs[i]) == 0:
                continue
            compressed_docs.append(
                Document(page_content=outputs[i], metadata=doc.metadata)  # type: ignore[arg-type]
            )
        return compressed_docs

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        get_input: Optional[Callable[[str, Document], str]] = None,
        llm_chain_kwargs: Optional[dict] = None,
    ) -> LLMChainExtractor:
        """Initialize from LLM.

        Note: ``llm_chain_kwargs`` is accepted but currently unused.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        _get_input = get_input if get_input is not None else default_get_input
        if _prompt.output_parser is not None:
            parser = _prompt.output_parser
        else:
            parser = StrOutputParser()
        llm_chain = _prompt | llm | parser
        return cls(llm_chain=llm_chain, get_input=_get_input)  # type: ignore[arg-type]
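
A minimal usage sketch (illustrative, not part of the module above). It assumes
the ``langchain_openai`` package is installed and an OpenAI API key is
configured; any model implementing ``BaseLanguageModel`` works the same way.

    from langchain_core.documents import Document
    from langchain_openai import ChatOpenAI

    compressor = LLMChainExtractor.from_llm(ChatOpenAI(temperature=0))
    docs = [
        Document(page_content="Paris is the capital of France. It rains often."),
        Document(page_content="Bananas are rich in potassium."),
    ]
    compressed = compressor.compress_documents(
        docs, query="What is the capital of France?"
    )
    # Documents for which the model returns NO_OUTPUT are dropped; the rest
    # keep only the extracted, query-relevant text.

In practice this compressor is typically wrapped in a
``ContextualCompressionRetriever`` so that retrieved documents are compressed
before being passed to a downstream chain.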