Source code for langchain_cohere.chains.summarize.summarize_chain

"""Load summarizing chains."""
from typing import Any, Callable, Dict, List, Optional, Union

from langchain_core._api import beta
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import (
    BasePromptTemplate,
    ChatPromptTemplate,
)
from langchain_core.prompts.chat import (
    BaseMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda, RunnableSerializable

from langchain_cohere.chains.summarize.prompt import RAG_SUMMARIZATION_PREAMBLE
from langchain_cohere.chat_models import ChatCohere


def create_summarize_prompt(
    prompt_message: BaseMessage = HumanMessage(
        content="Please summarize the documents in a concise manner."
    ),
    extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
) -> ChatPromptTemplate:
    """Create the chat prompt used by the summarization chain.

    Args:
        prompt_message: Message carrying the summarization instruction. It is
            placed first in the prompt. If it is falsy (e.g. ``None``), it is
            omitted and only ``extra_prompt_messages`` are used.
        extra_prompt_messages: Optional prompt messages appended after
            ``prompt_message``. ``None`` or an empty sequence adds nothing.

    Returns:
        A ``ChatPromptTemplate`` built from ``prompt_message`` (when given)
        followed by ``extra_prompt_messages``.
    """
    extras = list(extra_prompt_messages or [])
    messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
    if prompt_message:
        # Previously the branches were swapped: a real prompt message dropped
        # the extras, and a missing one inserted ``None`` into the list.
        messages = [prompt_message, *extras]
    else:
        messages = [*extras]
    return ChatPromptTemplate(messages=messages)
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
) -> RunnableSerializable:
    """Build the "stuff" summarization runnable.

    If ``llm`` has a ``preamble`` attribute in its instance ``__dict__`` whose
    value is falsy, a new ``ChatCohere`` is constructed from that ``__dict__``
    and given ``RAG_SUMMARIZATION_PREAMBLE``. The caller's object is rebound
    locally, not mutated.

    Args:
        llm: Language model to call. Its ``bind(documents=...)`` result is
            used as the final step of the chain.
        prompt: Prompt template to render. When falsy, the prompt from
            ``create_summarize_prompt()`` is used.

    Returns:
        A runnable whose input dict must contain ``"documents"`` (plus any
        variables ``prompt`` needs). The rendered prompt is sent to ``llm``
        with those documents bound.
    """
    attrs = llm.__dict__
    if "preamble" in attrs and not attrs.get("preamble"):
        llm = ChatCohere(**attrs)
        llm.preamble = RAG_SUMMARIZATION_PREAMBLE

    template = prompt if prompt else create_summarize_prompt()

    def _pick_documents(inputs: dict) -> Any:
        return inputs["documents"]

    def _render(inputs: dict) -> Any:
        return template.format_prompt(**inputs)

    def _grounded_llm(inputs: dict) -> RunnableSerializable[Any, Any]:
        # This function returns a runnable rather than a value; LangChain's
        # RunnableLambda then invokes that runnable with the same ``inputs``.
        # The documents are bound per call from this input.
        take_prompt = RunnableLambda(lambda payload: payload["input"])
        return take_prompt | llm.bind(documents=inputs["documents"])

    prepared = RunnablePassthrough.assign(
        documents=_pick_documents,
        input=_render,
    )
    return prepared | _grounded_llm
@beta(
    message="Makes use of Cohere's grounded RAG summarization, "
    "which may change in a later langchain-cohere version"
)
def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    **kwargs: Any,
) -> RunnableSerializable:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Currently, only
            "stuff" is supported in this implementation.
        **kwargs: Extra keyword arguments forwarded unchanged to the loader
            selected by ``chain_type``. For "stuff" this is ``prompt``, an
            optional ``BasePromptTemplate``.

    Returns:
        A chain to use for summarizing.

    Raises:
        ValueError: If ``chain_type`` is not a supported chain type.
    """
    # Loaders are called as ``loader(llm, **kwargs)``. The loader signature
    # is therefore not fixed positionally, hence ``Callable[..., ...]``.
    loader_mapping: Dict[str, Callable[..., RunnableSerializable[Any, Any]]] = {
        "stuff": _load_stuff_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, **kwargs)