# Source code for langchain_core.tools.structured
from __future__ import annotations
import textwrap
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolCall
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableConfig, run_in_executor
from langchain_core.tools.base import (
FILTERED_ARGS,
BaseTool,
_get_runnable_config_param,
create_schema_from_function,
)
from langchain_core.utils.pydantic import TypeBaseModel
class StructuredTool(BaseTool):
    """Tool that can operate on any number of inputs.

    The tool wraps a synchronous callable (``func``), an asynchronous
    callable (``coroutine``), or both, together with the schema that
    describes their arguments.
    """

    description: str = ""
    args_schema: TypeBaseModel = Field(..., description="The tool schema.")
    """The input arguments' schema."""
    func: Optional[Callable[..., Any]]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[Any]]] = None
    """The asynchronous version of the function."""

    # --- Runnable ---

    # TODO: Is this needed?
    async def ainvoke(
        self,
        input: Union[str, Dict, ToolCall],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        """Invoke the tool asynchronously.

        Args:
            input: The tool input (a string, a dict of arguments or a tool call).
            config: Optional runnable configuration.
            kwargs: Additional keyword arguments forwarded to the invocation.

        Returns:
            The result of the tool invocation.
        """
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await run_in_executor(config, self.invoke, input, config, **kwargs)

        return await super().ainvoke(input, config, **kwargs)

    # --- Tool ---

    @property
    def args(self) -> dict:
        """The tool's input arguments."""
        return self.args_schema.schema()["properties"]

    def _run(
        self,
        *args: Any,
        config: RunnableConfig,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool.

        Args:
            args: Positional arguments forwarded to ``func``.
            config: The runnable configuration; forwarded to ``func`` only when
                ``_get_runnable_config_param`` reports a parameter name for it.
            run_manager: Optional callback manager for this run.
            kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            NotImplementedError: If no synchronous ``func`` is set.
        """
        if self.func:
            # Only hand over child callbacks when the wrapped function
            # declares a ``callbacks`` parameter.
            if run_manager and signature(self.func).parameters.get("callbacks"):
                kwargs["callbacks"] = run_manager.get_child()
            if config_param := _get_runnable_config_param(self.func):
                kwargs[config_param] = config
            return self.func(*args, **kwargs)
        raise NotImplementedError("StructuredTool does not support sync invocation.")

    async def _arun(
        self,
        *args: Any,
        config: RunnableConfig,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously.

        Args:
            args: Positional arguments forwarded to ``coroutine``.
            config: The runnable configuration; forwarded to ``coroutine`` only
                when ``_get_runnable_config_param`` reports a parameter name
                for it.
            run_manager: Optional async callback manager for this run.
            kwargs: Keyword arguments forwarded to ``coroutine``.

        Returns:
            Whatever the awaited ``coroutine`` returns.
        """
        if self.coroutine:
            # Only hand over child callbacks when the wrapped coroutine
            # declares a ``callbacks`` parameter.
            if run_manager and signature(self.coroutine).parameters.get("callbacks"):
                kwargs["callbacks"] = run_manager.get_child()
            if config_param := _get_runnable_config_param(self.coroutine):
                kwargs[config_param] = config
            return await self.coroutine(*args, **kwargs)

        # NOTE: this code is unreachable since _arun is only called if coroutine is not
        # None.
        return await super()._arun(
            *args, config=config, run_manager=run_manager, **kwargs
        )

    @classmethod
    def from_function(
        cls,
        func: Optional[Callable] = None,
        coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        *,
        response_format: Literal["content", "content_and_artifact"] = "content",
        parse_docstring: bool = False,
        error_on_invalid_docstring: bool = False,
        **kwargs: Any,
    ) -> StructuredTool:
        """Create tool from a given function.

        A classmethod that helps to create a tool from a function.

        Args:
            func: The function from which to create a tool.
            coroutine: The async function from which to create a tool.
            name: The name of the tool. Defaults to the function name.
            description: The description of the tool.
                Defaults to the function docstring.
            return_direct: Whether to return the result directly or as a callback.
                Defaults to False.
            args_schema: The schema of the tool's input arguments. Defaults to None.
            infer_schema: Whether to infer the schema from the function's signature.
                Defaults to True.
            response_format: The tool response format. If "content" then the output of
                the tool is interpreted as the contents of a ToolMessage. If
                "content_and_artifact" then the output is expected to be a two-tuple
                corresponding to the (content, artifact) of a ToolMessage.
                Defaults to "content".
            parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
                to parse parameter descriptions from Google Style function docstrings.
                Defaults to False.
            error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
                whether to raise ValueError on invalid Google Style docstrings.
                Defaults to False.
            kwargs: Additional arguments to pass to the tool

        Returns:
            The tool.

        Raises:
            ValueError: If neither the function nor the coroutine is provided,
                or if no description is given and none can be taken from the
                function or schema docstring.

        Examples:

            .. code-block:: python

                def add(a: int, b: int) -> int:
                    \"\"\"Add two numbers\"\"\"
                    return a + b
                tool = StructuredTool.from_function(add)
                tool.run({"a": 1, "b": 2})  # 3
        """
        # The sync function takes precedence as the source of name, docstring
        # and signature; the coroutine is only used when no function is given.
        if func is not None:
            source_function = func
        elif coroutine is not None:
            source_function = coroutine
        else:
            raise ValueError("Function and/or coroutine must be provided")
        name = name or source_function.__name__
        if args_schema is None and infer_schema:
            # schema name is appended within function
            args_schema = create_schema_from_function(
                name,
                source_function,
                parse_docstring=parse_docstring,
                error_on_invalid_docstring=error_on_invalid_docstring,
                filter_args=_filter_schema_args(source_function),
            )
        # Resolve the description: explicit argument first, then the function
        # docstring (unless it is being parsed for parameter descriptions),
        # then the schema docstring.
        description_ = description
        if description is None and not parse_docstring:
            description_ = source_function.__doc__ or None
        if description_ is None and args_schema:
            description_ = args_schema.__doc__ or None
        if description_ is None:
            raise ValueError(
                "Function must have a docstring if description not provided."
            )
        if description is None:
            # Only apply if using the function's docstring
            description_ = textwrap.dedent(description_).strip()

        # Normalize surrounding whitespace of the final description.
        description_ = description_.strip()
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            args_schema=args_schema,  # type: ignore[arg-type]
            description=description_,
            return_direct=return_direct,
            response_format=response_format,
            **kwargs,
        )
def _filter_schema_args(func: Callable) -> List[str]:
    """Build the ``filter_args`` list used when inferring a schema for ``func``.

    Starts from a copy of ``FILTERED_ARGS`` and appends the parameter name
    reported by ``_get_runnable_config_param`` for ``func``, when there is one.

    Args:
        func: The callable whose signature is being turned into a schema.

    Returns:
        A new list of argument names; ``FILTERED_ARGS`` itself is not mutated.
    """
    filter_args = list(FILTERED_ARGS)
    if config_param := _get_runnable_config_param(func):
        filter_args.append(config_param)
    return filter_args