"""Declarative chat agent node for Azure AI Foundry agents."""
import base64
import json
import logging
import tempfile
import time
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Union,
)
from azure.ai.agents.models import (
Agent,
FunctionDefinition,
FunctionTool,
FunctionToolDefinition,
ListSortOrder,
MessageInputTextBlock,
RequiredFunctionToolCall,
StructuredToolOutput,
SubmitToolOutputsAction,
ThreadMessage,
Tool,
ToolDefinition,
ToolOutput,
ToolResources,
ToolSet,
)
from azure.ai.projects import AIProjectClient
from azure.core.exceptions import HttpResponseError
from langchain_core.messages import (
AIMessage,
HumanMessage,
ToolCall,
ToolMessage,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import (
convert_to_openai_function,
)
from langgraph._internal._runnable import RunnableCallable
from langgraph.graph import MessagesState
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.store.base import BaseStore
from langchain_azure_ai.agents.prebuilt.tools import (
AgentServiceBaseTool,
_OpenAIFunctionTool,
)
# Logger named after the containing package (``__package__``) rather than
# after this module.
logger = logging.getLogger(name=__package__)
def _required_tool_calls_to_message(
    required_tool_call: RequiredFunctionToolCall,
) -> AIMessage:
    """Convert a RequiredFunctionToolCall to an AIMessage with one tool call.

    Args:
        required_tool_call: The function tool call requested by an agent run.

    Returns:
        An ``AIMessage`` with empty content whose ``tool_calls`` contains a
        single ``ToolCall`` carrying the call's ID, function name and parsed
        JSON arguments.

    Raises:
        json.JSONDecodeError: If the arguments are a non-empty string that is
            not valid JSON.
    """
    raw_arguments = required_tool_call.function.arguments
    # An empty (or missing) argument string is not valid JSON; treat it as a
    # call without arguments instead of failing.
    args = json.loads(raw_arguments) if raw_arguments else {}
    return AIMessage(
        content="",
        tool_calls=[
            ToolCall(
                id=required_tool_call.id,
                name=required_tool_call.function.name,
                args=args,
            )
        ],
    )
def _tool_message_to_output(tool_message: ToolMessage) -> StructuredToolOutput:
    """Convert a ToolMessage to a ToolOutput for ``submit_tool_outputs``.

    Args:
        tool_message: The tool result produced for one of the run's tool calls.

    Returns:
        A ``ToolOutput`` for ``tool_message.tool_call_id``. String content is
        passed through unchanged; any other content (e.g. a list of LangChain
        content blocks) is serialized to a JSON string, since the output is
        sent as a string.
    """
    # TODO: Add support to artifacts
    content = tool_message.content
    output = content if isinstance(content, str) else json.dumps(content)
    return ToolOutput(
        tool_call_id=tool_message.tool_call_id,
        output=output,
    )
def _get_tool_resources(
    tools: Union[
        Sequence[Union[AgentServiceBaseTool, BaseTool, Callable]],
        ToolNode,
    ],
) -> Union[ToolResources, None]:
    """Get the tool resources declared by the tools given to the agent.

    Only ``AgentServiceBaseTool`` instances carry resources; other tool kinds
    are skipped.

    Args:
        tools: The tools given to the agent. Lists and tuples are inspected;
            any other value (e.g. a ``ToolNode``) yields ``None``.

    Returns:
        The ``tool.resources`` of the first ``AgentServiceBaseTool`` whose
        resources are not ``None``, or ``None`` if there is no such tool.
    """
    if not isinstance(tools, (list, tuple)):
        return None
    # NOTE(review): resources of later tools are not merged -- only the first
    # tool that declares resources contributes. Confirm this is intended.
    for tool in tools:
        if isinstance(tool, AgentServiceBaseTool) and tool.tool.resources is not None:
            return tool.tool.resources
    return None
def _get_tool_definitions(
    tools: Union[
        Sequence[Union[AgentServiceBaseTool, BaseTool, Callable]],
        ToolNode,
    ],
) -> List[ToolDefinition]:
    """Convert the tools given to the agent into Azure AI Agents definitions.

    The tools are collected into a ``ToolSet``:

    * ``AgentServiceBaseTool`` instances contribute their wrapped ``tool``.
    * LangChain ``BaseTool`` instances are converted to OpenAI function
      definitions and grouped into one ``_OpenAIFunctionTool``.
    * Other callables are grouped into one ``FunctionTool``.

    Args:
        tools: A list or tuple of tools. Each entry can be an
            ``AgentServiceBaseTool``, a ``BaseTool``, or a callable.

    Returns:
        The ``definitions`` of the assembled ``ToolSet``.

    Raises:
        ValueError: If ``tools`` is a ``ToolNode`` or not a list/tuple, or if
            an entry is a raw azure-ai-agents ``Tool`` or another unsupported
            type.
    """
    if isinstance(tools, ToolNode):
        raise ValueError(
            "ToolNode is not supported as a tool input. Use a list of tools instead."
        )
    if not isinstance(tools, (list, tuple)):
        raise ValueError("tools must be a list or a ToolNode.")

    toolset = ToolSet()
    function_tools: set[Callable] = set()
    openai_tools: List[FunctionToolDefinition] = []

    for tool in tools:
        if isinstance(tool, AgentServiceBaseTool):
            logger.debug("Adding AgentService tool: %s", tool.tool)
            toolset.add(tool.tool)
        elif isinstance(tool, BaseTool):
            function_def = convert_to_openai_function(tool)
            logger.debug("Adding OpenAI function tool: %s", function_def["name"])
            openai_tools.append(
                FunctionToolDefinition(
                    function=FunctionDefinition(
                        name=function_def["name"],
                        # The description may be absent from the converted
                        # function; don't fail on a missing key.
                        description=function_def.get("description"),
                        parameters=function_def.get(
                            "parameters", {"type": "object", "properties": {}}
                        ),
                    )
                )
            )
        elif callable(tool):
            # Not every callable has __name__ (e.g. functools.partial objects).
            logger.debug(
                "Adding callable function tool: %s",
                getattr(tool, "__name__", repr(tool)),
            )
            function_tools.add(tool)
        elif isinstance(tool, Tool):
            raise ValueError(
                "Passing raw Tool definitions from package azure-ai-agents "
                "is not supported. Wrap the tool in "
                "langchain_azure_ai.agents.prebuilt.tools.AgentServiceBaseTool"
                " and pass `tool=<your_tool>`."
            )
        else:
            raise ValueError(
                "Each tool must be an AgentServiceBaseTool, BaseTool, or a "
                f"callable. Got {type(tool)}"
            )

    if function_tools:
        toolset.add(FunctionTool(function_tools))
    if openai_tools:
        toolset.add(_OpenAIFunctionTool(openai_tools))
    return toolset.definitions
[docs]
class DeclarativeChatAgentNode(RunnableCallable):
    """A LangGraph node that represents a declarative chat agent in Azure AI Foundry.

    You can use this node to create complex graphs that involve interactions with
    an Azure AI Foundry agent.

    You can also use `langchain_azure_ai.agents.AgentServiceFactory` to create
    instances of this node.

    Example:
        .. code-block:: python

            from azure.identity import DefaultAzureCredential
            from langchain_azure_ai.agents import AgentServiceFactory

            factory = AgentServiceFactory(
                project_endpoint=(
                    "https://resource.services.ai.azure.com/api/projects/demo-project"
                ),
                credential=DefaultAzureCredential(),
            )

            coder = factory.create_declarative_chat_node(
                name="code-interpreter-agent",
                model="gpt-4.1",
                instructions="You are a helpful assistant that can run Python code.",
                tools=[func1, func2],
            )
    """

    name: str = "DeclarativeChatAgent"

    _client: AIProjectClient
    """The AIProjectClient instance to use."""

    _agent: Optional[Agent] = None
    """The agent instance to use."""

    _agent_name: Optional[str] = None
    """The name of the agent to create or use."""

    _agent_id: Optional[str] = None
    """The ID of the agent to use. If not provided, a new agent will be created."""

    _thread_id: Optional[str] = None
    """The ID of the conversation thread to use. If not provided, a new thread will be
    created."""

    _pending_run_id: Optional[str] = None
    """The ID of the pending run, if any."""

    _polling_interval: int = 1
    """The interval (in seconds) to poll for updates on the agent's status."""

    def __init__(
        self,
        client: AIProjectClient,
        model: str,
        instructions: str,
        name: str,
        description: Optional[str] = None,
        agent_id: Optional[str] = None,
        response_format: Optional[Dict[str, Any]] = None,
        tools: Optional[
            Union[
                Sequence[Union[AgentServiceBaseTool, BaseTool, Callable]],
                ToolNode,
            ]
        ] = None,
        tool_resources: Optional[Any] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        polling_interval: int = 1,
        tags: Optional[Sequence[str]] = None,
        trace: bool = True,
    ) -> None:
        """Initialize the DeclarativeChatAgentNode.

        When ``agent_id`` is given, that existing agent is fetched and reused,
        and the agent-creation arguments (``model``, ``instructions``,
        ``tools``, ...) are not used. Otherwise a new agent is created in the
        connected project.

        Args:
            client: The AIProjectClient instance to use.
            model: The model to use for the agent.
            instructions: The prompt instructions to use for the agent.
            name: The name of the agent; also used as the node name.
            description: An optional description for the agent.
            agent_id: The ID of an existing agent to use. If not provided, a new
                agent will be created.
            response_format: The response format to use for the agent.
            tools: A list of tools to use with the agent. Each tool can be an
                ``AgentServiceBaseTool``, a LangChain ``BaseTool`` or a plain
                callable.
            tool_resources: Optional tool resources to use with the agent. They
                are replaced by the resources declared by ``tools``, if any.
            temperature: The temperature to use for the agent.
            top_p: The top_p value to use for the agent.
            polling_interval: The interval (in seconds) to poll for updates on the
                agent's status. Defaults to 1 second.
            tags: Optional tags for the node; they are also sent as the agent's
                ``metadata`` when creating a new agent.
            trace: Whether to enable tracing for the node. Defaults to True.

        Raises:
            ValueError: If ``agent_id`` is given but cannot be retrieved, or if
                ``tools`` contains an unsupported tool.
        """
        super().__init__(self._func, self._afunc, name=name, tags=tags, trace=trace)
        self._client = client
        self._polling_interval = polling_interval

        if agent_id is not None:
            try:
                self._agent = self._client.agents.get_agent(agent_id=agent_id)
            except HttpResponseError as e:
                raise ValueError(
                    f"Could not find agent with ID {agent_id} in the "
                    "connected project. Do not pass agent_id when "
                    "creating a new agent."
                ) from e
            self._agent_id = self._agent.id
            self._agent_name = self._agent.name
            # Reuse the existing agent; do not create another one.
            return

        agent_params: Dict[str, Any] = {
            "model": model,
            "name": name,
            "instructions": instructions,
        }

        # Add optional parameters
        if description:
            agent_params["description"] = description
        if tool_resources:
            agent_params["tool_resources"] = tool_resources
        if tags:
            # NOTE(review): the service's agent ``metadata`` is presumably a
            # string-to-string mapping while ``tags`` is a sequence of strings;
            # confirm the service accepts this value.
            agent_params["metadata"] = tags
        if temperature is not None:
            agent_params["temperature"] = temperature
        if top_p is not None:
            agent_params["top_p"] = top_p
        if response_format is not None:
            agent_params["response_format"] = response_format

        if tools is not None:
            agent_params["tools"] = _get_tool_definitions(tools)
            derived_resources = _get_tool_resources(tools)
            if derived_resources is not None:
                agent_params["tool_resources"] = derived_resources

        self._agent = client.agents.create_agent(**agent_params)
        self._agent_id = self._agent.id
        self._agent_name = name

        logger.info(
            "Created agent with name: %s (%s)", self._agent.name, self._agent.id
        )

    def _to_langchain_message(self, msg: ThreadMessage) -> Optional[AIMessage]:
        """Convert an Azure AI Foundry thread message to a LangChain message.

        Text parts become string content. Image files are downloaded with
        ``self._client.agents.files.save`` into a temporary directory and
        embedded as base64 ``image`` content blocks. The file name comes from a
        ``file_path`` annotation with the same file ID, otherwise it is
        ``<file_id>.png``. The MIME type is guessed from that name and falls
        back to ``image/png``.

        Args:
            msg: The message from Azure AI Foundry.

        Returns:
            An ``AIMessage`` whose content is a plain string when the message
            consists of exactly one text part, a list of parts otherwise, or
            ``None`` when no text or image content was found.
        """
        import mimetypes
        import os

        contents: List[Union[str, Dict[Any, Any]]] = []
        file_paths: Dict[str, str] = {}

        for text in msg.text_messages or []:
            contents.append(text.text.value)

        for ann in msg.file_path_annotations or []:
            logger.info("Found file path annotation: %s with text %s", ann.type, ann.text)
            if ann.type == "file_path":
                file_paths[ann.file_path.file_id] = ann.text.split("/")[-1]

        for img in msg.image_contents or []:
            file_id = img.image_file.file_id
            file_name = file_paths.get(file_id, f"{file_id}.png")
            guessed_type = mimetypes.guess_type(file_name)[0]
            mime_type = (
                guessed_type
                if guessed_type and guessed_type.startswith("image/")
                else "image/png"
            )
            with tempfile.TemporaryDirectory() as target_dir:
                logger.info("Downloading image file %s as %s", file_id, file_name)
                self._client.agents.files.save(
                    file_id=file_id,
                    file_name=file_name,
                    target_dir=target_dir,
                )
                with open(os.path.join(target_dir, file_name), "rb") as f:
                    content = f.read()
            contents.append(
                {
                    "type": "image",
                    "mime_type": mime_type,
                    "base64": base64.b64encode(content).decode("utf-8"),
                }
            )

        if not contents:
            return None
        # Only a lone text part may be unwrapped: AIMessage content must be a
        # string or a list, never a bare content-block dict.
        if len(contents) == 1 and isinstance(contents[0], str):
            return AIMessage(content=contents[0])
        return AIMessage(content=contents)  # type: ignore[arg-type]

    def delete_agent_from_node(self) -> None:
        """Delete an agent associated with a DeclarativeChatAgentNode node.

        Raises:
            ValueError: If the node has no associated agent ID.
        """
        if self._agent_id is None:
            raise ValueError(
                "The node does not have an associated agent ID to eliminate"
            )
        self._client.agents.delete_agent(self._agent_id)
        logger.info("Deleted agent with ID: %s", self._agent_id)
        self._agent_id = None
        self._agent = None

    def _submit_pending_tool_outputs(
        self, thread_id: str, messages: Sequence[Any]
    ) -> None:
        """Submit the trailing ToolMessages of ``messages`` to the pending run.

        All consecutive ``ToolMessage`` objects at the end of ``messages`` are
        submitted together, so every tool call of a multi-call run is answered.
        """
        trailing: List[ToolMessage] = []
        for candidate in reversed(messages):
            if not isinstance(candidate, ToolMessage):
                break
            trailing.append(candidate)
        trailing.reverse()

        for tool_message in trailing:
            logger.info("Submitting tool message with ID %s", tool_message.id)

        if not self._pending_run_id:
            raise RuntimeError(
                "No pending run to submit tool outputs to. Got ToolMessage "
                "without a pending run."
            )

        run = self._client.agents.runs.get(
            thread_id=thread_id, run_id=self._pending_run_id
        )
        if not (
            run.status == "requires_action"
            and isinstance(run.required_action, SubmitToolOutputsAction)
        ):
            raise RuntimeError(
                f"Run {self._pending_run_id} is not in a state to accept "
                "tool outputs."
            )

        self._client.agents.runs.submit_tool_outputs(
            thread_id=thread_id,
            run_id=self._pending_run_id,
            tool_outputs=[_tool_message_to_output(m) for m in trailing],
        )

    def _post_human_message(self, thread_id: str, message: HumanMessage) -> None:
        """Post a HumanMessage to the thread as a ``user`` message."""
        logger.info("Submitting human message %s", message.content)
        content = message.content
        if isinstance(content, str):
            self._client.agents.messages.create(
                thread_id=thread_id, role="user", content=content
            )
        elif isinstance(content, dict):
            raise RuntimeError(
                "Message content as dict is not supported yet. "
                "Please submit as string."
            )
        elif isinstance(content, list):
            # Plain strings in a content list are text blocks; dict blocks are
            # passed through as the block's fields.
            blocks = [
                MessageInputTextBlock(text=block)
                if isinstance(block, str)
                else MessageInputTextBlock(block)  # type: ignore[arg-type]
                for block in content
            ]
            self._client.agents.messages.create(
                thread_id=thread_id,
                role="user",
                content=blocks,
            )

    def _start_or_resume_run(self, thread_id: str) -> Any:
        """Create a new run, or fetch the pending one if a run is in flight."""
        if self._pending_run_id is None:
            logger.info("Creating and processing new run...")
            return self._client.agents.runs.create(
                thread_id=thread_id,
                agent_id=self._agent_id,
            )
        logger.info("Getting existing run %s...", self._pending_run_id)
        return self._client.agents.runs.get(
            thread_id=thread_id, run_id=self._pending_run_id
        )

    def _poll_run(self, thread_id: str, run: Any) -> Any:
        """Re-fetch ``run`` every polling interval while it is still active."""
        while run.status in ("queued", "in_progress", "cancelling"):
            time.sleep(self._polling_interval)
            run = self._client.agents.runs.get(thread_id=thread_id, run_id=run.id)
        return run

    def _run_to_messages(self, thread_id: str, run: Any) -> List[AIMessage]:
        """Turn a run that is no longer active into new LangChain messages.

        Raises:
            ValueError: If the run requests a non-function tool call.
            RuntimeError: If the run failed or ended in any other state than
                ``completed`` or ``requires_action`` (tool outputs).
        """
        if run.status == "requires_action" and isinstance(
            run.required_action, SubmitToolOutputsAction
        ):
            # All tool calls go into ONE AIMessage so a downstream ToolNode
            # executes every call of the run, not just the last one.
            tool_calls: List[ToolCall] = []
            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                if not isinstance(tool_call, RequiredFunctionToolCall):
                    raise ValueError(
                        f"Unsupported tool call type: {type(tool_call)} in run "
                        f"{run.id}."
                    )
                tool_calls.extend(_required_tool_calls_to_message(tool_call).tool_calls)
            self._pending_run_id = run.id
            return [AIMessage(content="", tool_calls=tool_calls)]

        # Every remaining state is terminal for this node: forget the run so a
        # later call starts a fresh one instead of re-reading this run.
        self._pending_run_id = None

        if run.status == "completed":
            response = self._client.agents.messages.list(
                thread_id=thread_id,
                run_id=run.id,
                order=ListSortOrder.ASCENDING,
            )
            new_messages: List[AIMessage] = []
            for msg in response:
                converted = self._to_langchain_message(msg)
                if converted is not None:
                    new_messages.append(converted)
            return new_messages

        if run.status == "failed":
            raise RuntimeError(f"Run {run.id} failed with error: {run.last_error}")
        raise RuntimeError(
            f"Run {run.id} finished with unexpected status: {run.status}"
        )

    def _func(
        self,
        input: MessagesState,
        config: RunnableConfig,
        *,
        store: Optional[BaseStore],
    ) -> Any:
        """Run the agent for the newest message in ``input``.

        A trailing ``HumanMessage`` is posted to the node's thread. Trailing
        ``ToolMessage`` objects are submitted as outputs of the pending run.
        The run is then polled until it is no longer active.

        Args:
            input: Graph state; only ``input["messages"]`` is read.
            config: Runnable config (not used).
            store: Graph store (not used).

        Returns:
            A state update ``{"messages": [...]}``. It holds one ``AIMessage``
            with the requested tool calls when the run needs tool outputs, or
            the run's messages once it completes. ``input`` is not mutated.

        Raises:
            RuntimeError: If the agent was deleted, the message type is
                unsupported, tool outputs cannot be submitted, or the run did
                not complete.
            ValueError: If ``input`` contains no messages or the run requests
                a non-function tool call.
        """
        if self._agent is None or self._agent_id is None:
            raise RuntimeError(
                "The agent has not been initialized properly "
                "its associated agent in Azure AI Foundry "
                "has been deleted."
            )

        if self._thread_id is None:
            thread = self._client.agents.threads.create()
            self._thread_id = thread.id
            logger.info("Created new thread with ID: %s", self._thread_id)
        assert self._thread_id is not None
        thread_id: str = self._thread_id

        messages = input["messages"]
        if len(messages) == 0:
            raise ValueError("Input state must contain at least one message.")
        message = messages[-1]

        if isinstance(message, ToolMessage):
            self._submit_pending_tool_outputs(thread_id, messages)
        elif isinstance(message, HumanMessage):
            self._post_human_message(thread_id, message)
        else:
            raise RuntimeError(f"Unsupported message type: {type(message)}")

        run = self._poll_run(thread_id, self._start_or_resume_run(thread_id))
        return {"messages": self._run_to_messages(thread_id, run)}

    async def _afunc(
        self,
        input: MessagesState,
        config: RunnableConfig,
        *,
        store: Optional[BaseStore],
    ) -> Any:
        """Async variant of ``_func``: runs the blocking call in a worker thread."""
        import asyncio

        def _sync_func() -> Any:
            return self._func(input, config, store=store)

        return await asyncio.to_thread(_sync_func)