How to use few-shot prompting with tool calling

For more complex tool use, it helps to add few-shot examples to the prompt. We can do this by adding AIMessages with ToolCalls, along with the corresponding ToolMessages, to the prompt.

First let's define our tools and model.

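from langchain_core.tools import tool


# The @tool decorator turns each function into a tool; the function name,
# type hints, and docstring become the schema the model sees.
@tool
def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


@tool
def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b


tools = [add, multiply]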
import os
from getpass import getpass

from langchain_openai import ChatOpenAI

os.environ["OPENAI_API_KEY"] = getpass()

llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm_with_tools = llm.bind_tools(tools)

Let's run the model. Even with explicit instructions, it can get tripped up by order of operations.

llm_with_tools.invoke(
    "Whats 119 times 8 minus 20. Don't do any math yourself, only use tools for math. Respect order of operations"
).tool_calls
[{'name': 'Multiply',
'args': {'a': 119, 'b': 8},
'id': 'call_T88XN6ECucTgbXXkyDeC2CQj'},
{'name': 'Add',
'args': {'a': 952, 'b': -20},
'id': 'call_licdlmGsRqzup8rhqJSb1yZ4'}]

The model shouldn't be trying to add anything yet, since it can't know the result of 119 * 8 until the multiplication tool has actually run.
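To make that dependency concrete, here is a minimal sketch in plain Python (no LangChain and no model call) of executing tool calls one at a time. `TOOL_REGISTRY` and `run_tool_call` are hypothetical helpers introduced only for illustration, and the functions are undecorated stand-ins for the tools defined earlier. The point is that the arguments of the second call can only be built once the first result exists.

```python
from typing import Any, Callable, Dict


def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b


# Hypothetical registry mapping lowercase tool names to plain functions.
TOOL_REGISTRY: Dict[str, Callable[..., int]] = {"add": add, "multiply": multiply}


def run_tool_call(tool_call: Dict[str, Any]) -> int:
    """Look up the requested tool by name and call it with the given args."""
    # Lowercasing is an assumption so that names like 'Multiply' also resolve.
    func = TOOL_REGISTRY[tool_call["name"].lower()]
    return func(**tool_call["args"])


# Step 1: only the multiplication can be requested up front.
product = run_tool_call({"name": "Multiply", "args": {"a": 119, "b": 8}})

# Step 2: the subtraction is expressed as adding -20 to the real result of step 1.
result = run_tool_call({"name": "Add", "args": {"a": product, "b": -20}})

print(product, result)  # 952 932
```

In a real chain the model, not hand-written code, would propose the second call after seeing the first tool result.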

By adding a prompt with some examples we can correct this behavior:

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

examples = [
    HumanMessage("What's the product of 317253 and 128472 plus four", name="example_user"),
    # Each AIMessage requests one tool call; the ToolMessage that follows
    # returns its result, linked by the matching id.
    AIMessage("", tool_calls=[{"name": "Multiply", "args": {"x": 317253, "y": 128472}, "id": "1"}]),
    ToolMessage("40758127416", tool_call_id="1"),
    AIMessage("", tool_calls=[{"name": "Add", "args": {"x": 40758127416, "y": 4}, "id": "2"}]),
    ToolMessage("40758127420", tool_call_id="2"),
    AIMessage("The product of 317253 and 128472 plus four is 40758127420"),
]

system = """You are bad at math but are an expert at using a calculator.

Use past tool usage as an example of how to correctly use the tools."""
few_shot_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        # Splice the example messages between the system and human messages.
        *examples,
        ("human", "{query}"),
    ]
)

chain = {"query": RunnablePassthrough()} | few_shot_prompt | llm_with_tools
chain.invoke("Whats 119 times 8 minus 20").tool_calls
[{'name': 'Multiply',
'args': {'a': 119, 'b': 8},
'id': 'call_9MvuwQqg7dlJupJcoTWiEsDo'}]

This time the model requests only the multiplication, which is the correct first step.

Here's what the LangSmith trace looks like.
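When writing few-shot examples by hand, it is easy to leave a tool call without a matching tool result, or the reverse. The following is a minimal sketch of a sanity check, not part of LangChain: `unmatched_tool_call_ids` is a hypothetical helper that relies only on the `tool_calls` attribute of AI messages and the `tool_call_id` attribute of tool messages, and assumes every id is a string. The demo uses `SimpleNamespace` stand-ins so it runs without any dependencies.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence


def unmatched_tool_call_ids(messages: Sequence[Any]) -> List[str]:
    """Return ids that were requested but never answered, or answered but never requested."""
    requested = set()
    answered = set()
    for message in messages:
        # AI messages carry a list of tool calls; other messages have none.
        for call in getattr(message, "tool_calls", None) or []:
            requested.add(call["id"])
        # Tool messages point back at the call they answer.
        tool_call_id = getattr(message, "tool_call_id", None)
        if tool_call_id is not None:
            answered.add(tool_call_id)
    return sorted(requested ^ answered)


# Stand-in messages (hypothetical data) mirroring a human / AI / tool sequence.
demo = [
    SimpleNamespace(content="What's 2 times 3 plus 4"),
    SimpleNamespace(content="", tool_calls=[{"name": "multiply", "args": {"a": 2, "b": 3}, "id": "1"}]),
    SimpleNamespace(content="6", tool_call_id="1"),
    SimpleNamespace(content="", tool_calls=[{"name": "add", "args": {"a": 6, "b": 4}, "id": "2"}]),
    SimpleNamespace(content="10", tool_call_id="2"),
]

print(unmatched_tool_call_ids(demo))       # []
print(unmatched_tool_call_ids(demo[:-1]))  # ['2']
```

Because the helper only uses attribute lookups, the same function should also accept a list of real message objects such as `examples`, though that usage is untested here.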

