"""
The Cohere multi-hop agent enables multiple tools to be used in sequence to complete
a task.

This agent uses a multi-hop prompt by Cohere, which is experimental and subject
to change. The latest prompt can be used by upgrading the langchain-cohere package.
"""
from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableParallel,
RunnablePassthrough,
)
from langchain_core.tools import BaseTool
from langchain_cohere.react_multi_hop.parsing import (
GROUNDED_ANSWER_KEY,
OUTPUT_KEY,
CohereToolsReactAgentOutputParser,
parse_citations,
)
from langchain_cohere.react_multi_hop.prompt import (
convert_to_documents,
multi_hop_prompt,
)
def create_cohere_react_agent(
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    prompt: ChatPromptTemplate,
) -> Runnable:
    """
    Create an agent that enables multiple tools to be used in sequence to complete
    a task.

    Args:
        llm: The ChatCohere LLM instance to use.
        tools: Tools this agent has access to.
        prompt: The prompt to use.

    Returns:
        A Runnable sequence representing an agent. It takes as input all the same
        input variables as the prompt passed in does and returns a
        List[AgentAction] or a single AgentFinish.

        The AgentFinish will have two fields:
            * output: str - The output string generated by the model
            * citations: List[CohereCitation] - A list of citations that refer to
              the output and observations made by the agent. If there are no
              citations this list will be empty.

    Example:
        .. code-block:: python

            from langchain.agents import AgentExecutor
            from langchain.prompts import ChatPromptTemplate
            from langchain_cohere import ChatCohere, create_cohere_react_agent

            prompt = ChatPromptTemplate.from_template("{input}")
            tools = []  # Populate this with a list of tools you would like to use.

            llm = ChatCohere()

            agent = create_cohere_react_agent(
                llm,
                tools,
                prompt
            )
            agent_executor = AgentExecutor(agent=agent, tools=tools)
            agent_executor.invoke({
                "input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
            })
    """  # noqa: E501
    # Creates a prompt, invokes the model, and produces a
    # "Union[List[AgentAction], AgentFinish]"
    generate_agent_steps = (
        multi_hop_prompt(tools=tools, prompt=prompt)
        # Stop generation as soon as the model begins an "Observation:" line.
        # Observations are supplied from real tool runs (they are read back from
        # intermediate_steps when citations are built), not written by the model.
        # raw_prompting=True is forwarded to the Cohere model as-is; presumably it
        # sends the multi-hop prompt verbatim — confirm against ChatCohere.
        | llm.bind(stop=["\nObservation:"], raw_prompting=True)
        | CohereToolsReactAgentOutputParser()
    )

    agent = (
        RunnablePassthrough.assign(
            # agent_scratchpad isn't used in this chain, but added here for
            # interoperability with other chains that may require it.
            agent_scratchpad=lambda _: [],
        )
        # Run the step generator while also keeping the untouched chain input
        # under "chain_input"; _AddCitations needs its intermediate_steps.
        | RunnableParallel(
            chain_input=RunnablePassthrough(), agent_steps=generate_agent_steps
        )
        | _AddCitations()
    )
    return agent
class _AddCitations(Runnable):
    """
    Adds a list of citations to the output of the Cohere multi-hop chain when the
    last step is an AgentFinish. Citations are generated from the observations (made
    in previous agent steps) and the grounded answer (made in the last step).

    The input is expected to be a mapping with:
        * "agent_steps": the parsed model output, either a list of AgentAction or
          a single AgentFinish.
        * "chain_input": the original chain input; its optional
          "intermediate_steps" entry is an iterable of pairs whose second element
          is a tool observation.
    """

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Union[List[AgentAction], AgentFinish]:
        """
        Attach citations to a final AgentFinish; pass anything else through.

        Args:
            input: Mapping containing "agent_steps" and "chain_input".
            config: Unused; accepted for Runnable interface compatibility.
            **kwargs: Unused; accepted for Runnable interface compatibility.

        Returns:
            An empty list if "agent_steps" is missing or empty; the AgentAction
            list unchanged if this is not the final step; otherwise the
            AgentFinish with its return_values updated in place so that
            OUTPUT_KEY holds the parsed output and "citations" holds the list of
            citations (the GROUNDED_ANSWER_KEY entry is removed).
        """
        agent_steps = input.get("agent_steps", [])
        if not agent_steps:
            # The input wasn't as expected.
            return []

        if not isinstance(agent_steps, AgentFinish):
            # We're not on the AgentFinish step; pass the actions through.
            return agent_steps

        agent_finish = agent_steps
        return_values = agent_finish.return_values

        if GROUNDED_ANSWER_KEY not in return_values:
            # No grounded answer to cite from. Keep whatever output the model
            # produced instead of overwriting it with an empty parsed answer.
            return_values.setdefault(OUTPUT_KEY, "")
            return_values["citations"] = []
            return agent_finish

        # Build a list of documents from the intermediate_steps used in this chain.
        # `or` guards against keys that are present but explicitly None.
        chain_input = input.get("chain_input") or {}
        intermediate_steps = chain_input.get("intermediate_steps") or []
        documents: List[MutableMapping] = []
        for _, observation in intermediate_steps:
            documents.extend(convert_to_documents(observation))

        # Build a list of citations, if any, from the documents + grounded answer.
        grounded_answer = return_values.pop(GROUNDED_ANSWER_KEY)
        output, citations = parse_citations(
            grounded_answer=grounded_answer, documents=documents
        )
        return_values[OUTPUT_KEY] = output
        return_values["citations"] = citations
        return agent_finish