"""Load summarizing chains."""
from typing import Any , Mapping , Optional , Protocol
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt , refine_prompts , stuff_prompt
class LoadingCallable(Protocol):
    """Structural type for factories that build a combine-documents chain.

    Any callable accepting a language model plus arbitrary keyword
    arguments and returning a ``BaseCombineDocumentsChain`` satisfies
    this protocol.
    """

    def __call__(
        self,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> BaseCombineDocumentsChain:
        """Build and return the combine-documents chain."""
        ...
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Build a "stuff" summarization chain.

    All documents are inserted into a single prompt under
    ``document_variable_name`` and summarized in one LLM call.
    """
    # Pair the model with the summarization prompt; verbosity is
    # forwarded so both layers of the chain log consistently.
    inner_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
    # TODO: document prompt
    return StuffDocumentsChain(
        document_variable_name=document_variable_name,
        llm_chain=inner_chain,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )
def _load_map_reduce_chain(
    llm: BaseLanguageModel,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    reduce_llm: Optional[BaseLanguageModel] = None,
    collapse_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    token_max: int = 3000,
    callbacks: Callbacks = None,
    *,
    collapse_max_retries: Optional[int] = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Build a map-reduce summarization chain.

    Each document is summarized individually with ``map_prompt``, then
    the partial summaries are combined with ``combine_prompt``.  When the
    combined partials exceed ``token_max`` tokens, an optional collapse
    step (``collapse_prompt`` / ``collapse_llm``) shrinks them first.
    ``reduce_llm`` / ``collapse_llm`` default to ``llm`` when omitted.
    """
    # Per-document "map" step.
    mapper = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # "Reduce" step that merges the partial summaries; may use its own model.
    reducer = LLMChain(
        llm=reduce_llm or llm,
        prompt=combine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    combiner = StuffDocumentsChain(
        llm_chain=reducer,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        # A collapse model without a collapse prompt is a caller mistake.
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
        collapser = None
    else:
        collapser = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=collapse_llm or llm,
                prompt=collapse_prompt,
                verbose=verbose,  # type: ignore[arg-type]
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    reduce_stage = ReduceDocumentsChain(
        combine_documents_chain=combiner,
        collapse_documents_chain=collapser,
        token_max=token_max,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=mapper,
        reduce_documents_chain=reduce_stage,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        **kwargs,
    )
def _load_refine_chain(
    llm: BaseLanguageModel,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Build a "refine" summarization chain.

    The first document is summarized with ``question_prompt``; each
    subsequent document iteratively refines that summary via
    ``refine_prompt`` (the running answer is passed in under
    ``initial_response_name``).  ``refine_llm`` defaults to ``llm``.
    """
    first_pass = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
    refine_pass = LLMChain(
        llm=refine_llm or llm, prompt=refine_prompt, verbose=verbose  # type: ignore[arg-type]
    )
    return RefineDocumentsChain(
        initial_llm_chain=first_pass,
        refine_llm_chain=refine_pass,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )
def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of
            "stuff", "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that
            this applies to all chains that make up the final chain.

    Returns:
        A chain to use for summarizing.

    Raises:
        ValueError: If ``chain_type`` is not a supported chain type.
    """
    # Dispatch table from chain-type name to its factory function.
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    loader = loader_mapping.get(chain_type)
    if loader is None:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader(llm, verbose=verbose, **kwargs)