# Source code for langchain_cohere.chains.summarize.summarize_chain
"""Load summarizing chains."""
from typing import Any, Callable, Dict, List, Optional, Union
from langchain_core._api import beta
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
)
from langchain_core.prompts.chat import (
BaseMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda, RunnableSerializable
from langchain_cohere.chains.summarize.prompt import RAG_SUMMARIZATION_PREAMBLE
from langchain_cohere.chat_models import ChatCohere
# [docs]
def create_summarize_prompt(
    prompt_message: BaseMessage = HumanMessage(
        content="Please summarize the documents in a concise manner."
    ),
    extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
) -> ChatPromptTemplate:
    """Create the prompt used to ask the model for a summary.

    Args:
        prompt_message: Message used as the instruction that asks the model
            to summarize the documents.
        extra_prompt_messages: Additional prompt messages placed after
            ``prompt_message``. Defaults to no extra messages.

    Returns:
        A chat prompt template to pass into the summarization chain.
    """
    # Default changed from a shared mutable ``[]`` to ``None`` to avoid the
    # mutable-default-argument pitfall; callers see identical behavior.
    extras: List[BaseMessagePromptTemplate] = extra_prompt_messages or []
    # The previous branching dropped ``extra_prompt_messages`` whenever
    # ``prompt_message`` was truthy (i.e. always, given the default) and the
    # else branch still included the falsy message. Always combine both.
    messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] = (
        [prompt_message] + extras
    )
    return ChatPromptTemplate(messages=messages)
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
) -> RunnableSerializable:
    """Build the "stuff" summarization runnable around ``llm``."""
    # When the model exposes a ``preamble`` field that is unset, rebuild it
    # as a ChatCohere configured with the RAG summarization preamble.
    preamble_is_empty = (
        "preamble" in llm.__dict__ and not llm.__dict__.get("preamble")
    )
    if preamble_is_empty:
        llm = ChatCohere(**llm.__dict__)
        llm.preamble = RAG_SUMMARIZATION_PREAMBLE

    if not prompt:
        prompt = create_summarize_prompt()

    def llm_with_docs(input_: dict) -> RunnableSerializable[Any, Any]:
        # Bind the chain input's documents onto the model call so the model
        # can ground its summary on them, then feed it the formatted prompt.
        bound_llm = llm.bind(documents=input_["documents"])
        return RunnableLambda(lambda payload: payload["input"]) | bound_llm

    prepare = RunnablePassthrough.assign(
        documents=lambda payload: payload["documents"],
        input=lambda payload: prompt.format_prompt(**payload),  # type: ignore[union-attr]
    )
    return prepare | llm_with_docs
# [docs]
@beta(
    message="""Makes use of Cohere's grounded RAG summarization,
    which may change in a later langchain-cohere version"""
)
def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    **kwargs: Any,
) -> RunnableSerializable:
    """Load a summarizing chain.

    Args:
        llm: Language model to use in the chain.
        chain_type: Type of document-combining chain to use. Currently only
            "stuff" is supported in this implementation.
        **kwargs: Forwarded to the selected chain loader (e.g. ``prompt``
            for the "stuff" chain).

    Returns:
        A runnable to use for summarizing.

    Raises:
        ValueError: If ``chain_type`` is not a supported chain type.
    """
    # Dispatch table keyed by chain type so new loaders can be registered
    # in a single place.
    loader_mapping: Dict[
        str,
        Callable[
            [BaseLanguageModel[Any], BasePromptTemplate[Any]],
            RunnableSerializable[Any, Any],
        ],
    ] = {
        "stuff": _load_stuff_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, **kwargs)  # type: ignore[call-arg]