"""Source code for ``langchain_core.messages.ai``: AI message types and helpers."""

import json
from typing import Any, Literal, Optional, Union

from pydantic import model_validator
from typing_extensions import Self, TypedDict

from langchain_core.messages.base import (
    BaseMessage,
    BaseMessageChunk,
    merge_content,
)
from langchain_core.messages.tool import (
    InvalidToolCall,
    ToolCall,
    ToolCallChunk,
    default_tool_chunk_parser,
    default_tool_parser,
)
from langchain_core.messages.tool import (
    invalid_tool_call as create_invalid_tool_call,
)
from langchain_core.messages.tool import (
    tool_call as create_tool_call,
)
from langchain_core.messages.tool import (
    tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json


class UsageMetadata(TypedDict):
    """Token accounting attached to a message.

    The same three keys are used no matter which model produced the
    message, so callers can read usage without provider-specific logic.

    Example:

        .. code-block:: python

            {
                "input_tokens": 10,
                "output_tokens": 20,
                "total_tokens": 30
            }
    """

    #: Number of tokens consumed by the input (the prompt).
    input_tokens: int
    #: Number of tokens produced in the output (the completion).
    output_tokens: int
    #: Overall number of tokens.
    total_tokens: int
class AIMessage(BaseMessage):
    """Message from an AI.

    AIMessage is returned from a chat model as a response to a prompt.

    This message represents the output of the model and consists of both
    the raw output as returned by the model together with standardized fields
    (e.g., tool calls, usage metadata) added by the LangChain framework.
    """

    example: bool = False
    """Use to denote that a message is part of an example conversation.

    At the moment, this is ignored by most models. Usage is discouraged.
    """

    tool_calls: list[ToolCall] = []
    """If provided, tool calls associated with the message."""
    invalid_tool_calls: list[InvalidToolCall] = []
    """If provided, tool calls with parsing errors associated with the message."""
    usage_metadata: Optional[UsageMetadata] = None
    """If provided, usage metadata for a message, such as token counts.

    This is a standard representation of token usage that is consistent
    across models.
    """

    type: Literal["ai"] = "ai"
    """The type of the message (used for deserialization). Defaults to "ai"."""

    def __init__(
        self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
    ) -> None:
        """Pass in content as positional arg.

        Args:
            content: The content of the message.
            kwargs: Additional arguments to pass to the parent class.
        """
        super().__init__(content=content, **kwargs)

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        Returns:
            The namespace of the langchain object.
            Defaults to ["langchain", "schema", "messages"].
        """
        return ["langchain", "schema", "messages"]

    @property
    def lc_attributes(self) -> dict:
        """Attrs to be serialized even if they are derived from other init args."""
        return {
            "tool_calls": self.tool_calls,
            "invalid_tool_calls": self.invalid_tool_calls,
        }

    @model_validator(mode="before")
    @classmethod
    def _backwards_compat_tool_calls(cls, values: dict) -> Any:
        """Populate and normalize the tool-call fields before field validation.

        If none of ``tool_calls``, ``invalid_tool_calls`` or
        ``tool_call_chunks`` was supplied, try to derive them from
        ``additional_kwargs["tool_calls"]``. ``AIMessageChunk`` subclasses get
        ``tool_call_chunks`` via ``default_tool_chunk_parser``. Other classes
        get ``tool_calls`` / ``invalid_tool_calls`` via
        ``default_tool_parser``. This step is best-effort: any exception while
        parsing leaves the fields untouched.

        Afterwards every supplied tool-call-like dict is re-created through its
        factory (dropping any incoming ``"type"`` key), so the ``type`` entry
        is always set by the factory.

        Args:
            values: The raw input mapping handed to the model.

        Returns:
            The same mapping, possibly updated in place.
        """
        check_additional_kwargs = not any(
            values.get(k)
            for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
        )
        # ``additional_kwargs`` may be passed explicitly as ``None``. Treat it
        # as empty here so field validation can report it properly instead of
        # this validator crashing with AttributeError.
        if check_additional_kwargs and (
            raw_tool_calls := (values.get("additional_kwargs") or {}).get(
                "tool_calls"
            )
        ):
            try:
                if issubclass(cls, AIMessageChunk):  # type: ignore
                    values["tool_call_chunks"] = default_tool_chunk_parser(
                        raw_tool_calls
                    )
                else:
                    parsed_tool_calls, parsed_invalid_tool_calls = default_tool_parser(
                        raw_tool_calls
                    )
                    values["tool_calls"] = parsed_tool_calls
                    values["invalid_tool_calls"] = parsed_invalid_tool_calls
            except Exception:
                # Deliberately best-effort: unparseable legacy payloads are
                # ignored rather than failing message construction.
                pass

        # Ensure "type" is properly set on all tool call-like dicts.
        for key, factory in (
            ("tool_calls", create_tool_call),
            ("invalid_tool_calls", create_invalid_tool_call),
            ("tool_call_chunks", create_tool_call_chunk),
        ):
            if items := values.get(key):
                values[key] = [
                    factory(**{k: v for k, v in tc.items() if k != "type"})
                    for tc in items
                ]
        return values

    def pretty_repr(self, html: bool = False) -> str:
        """Return a pretty representation of the message.

        Args:
            html: Whether to return an HTML-formatted string.
                 Defaults to False.

        Returns:
            A pretty representation of the message.
        """
        base = super().pretty_repr(html=html)
        lines = []

        def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]:
            # Render one (possibly invalid) tool call as a list of text lines.
            out = [
                f" {tc.get('name', 'Tool')} ({tc.get('id')})",
                f" Call ID: {tc.get('id')}",
            ]
            if tc.get("error"):
                out.append(f" Error: {tc.get('error')}")
            out.append(" Args:")
            args = tc.get("args")
            # Invalid tool calls carry raw string args; valid ones carry a dict.
            if isinstance(args, str):
                out.append(f" {args}")
            elif isinstance(args, dict):
                for arg, value in args.items():
                    out.append(f" {arg}: {value}")
            return out

        if self.tool_calls:
            lines.append("Tool Calls:")
            for tc in self.tool_calls:
                lines.extend(_format_tool_args(tc))
        if self.invalid_tool_calls:
            lines.append("Invalid Tool Calls:")
            for itc in self.invalid_tool_calls:
                lines.extend(_format_tool_args(itc))
        return (base.strip() + "\n" + "\n".join(lines)).strip()
# Rebuild AIMessage's pydantic schema now that the class body is complete.
# NOTE(review): presumably needed to resolve forward references in the
# field annotations at import time — confirm.
AIMessage.model_rebuild()
class AIMessageChunk(AIMessage, BaseMessageChunk):
    """Message chunk from an AI."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["AIMessageChunk"] = "AIMessageChunk"  # type: ignore
    """The type of the message (used for deserialization).
    Defaults to "AIMessageChunk"."""

    tool_call_chunks: list[ToolCallChunk] = []
    """If provided, tool call chunks associated with the message."""

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        Returns:
            The namespace of the langchain object.
            Defaults to ["langchain", "schema", "messages"].
        """
        return ["langchain", "schema", "messages"]

    @property
    def lc_attributes(self) -> dict:
        """Attrs to be serialized even if they are derived from other init args."""
        return {
            "tool_calls": self.tool_calls,
            "invalid_tool_calls": self.invalid_tool_calls,
        }

    @model_validator(mode="after")
    def init_tool_calls(self) -> Self:
        """Keep ``tool_call_chunks`` and the parsed tool-call fields in sync.

        If no chunks were supplied, chunks are derived from ``tool_calls``
        (args serialized with ``json.dumps``) followed by
        ``invalid_tool_calls`` (args passed through unchanged), all with
        ``index=None``.

        Otherwise the chunks are authoritative, and ``tool_calls`` /
        ``invalid_tool_calls`` are rebuilt from them:

        * empty or ``None`` args become a tool call with ``{}`` args;
        * args that ``parse_partial_json`` turns into a dict become a tool call;
        * anything else (a parse error or non-dict JSON) becomes an invalid
          tool call with ``error=None``.

        Malformed chunks never raise; they are recorded as invalid tool calls.

        Returns:
            This instance, with its tool-call fields initialized.
        """
        if not self.tool_call_chunks:
            if self.tool_calls:
                self.tool_call_chunks = [
                    create_tool_call_chunk(
                        name=tc["name"],
                        args=json.dumps(tc["args"]),
                        id=tc["id"],
                        index=None,
                    )
                    for tc in self.tool_calls
                ]
            if self.invalid_tool_calls:
                # Build a new list rather than extending the current one in place.
                self.tool_call_chunks = [
                    *self.tool_call_chunks,
                    *(
                        create_tool_call_chunk(
                            name=tc["name"], args=tc["args"], id=tc["id"], index=None
                        )
                        for tc in self.invalid_tool_calls
                    ),
                ]
            return self

        tool_calls = []
        invalid_tool_calls = []
        for chunk in self.tool_call_chunks:
            raw_args = chunk["args"]
            try:
                # Missing args (None) or empty args ("") mean "no arguments".
                # This happens e.g. for tools that take no parameters. It is
                # not a malformed call; previously a None value made
                # parse_partial_json raise and the call was marked invalid.
                args_ = parse_partial_json(raw_args) if raw_args else {}
            except Exception:
                # parse_partial_json is opaque; any failure means "invalid".
                args_ = None
            if isinstance(args_, dict):
                tool_calls.append(
                    create_tool_call(
                        name=chunk["name"] or "",
                        args=args_,
                        id=chunk["id"],
                    )
                )
            else:
                invalid_tool_calls.append(
                    create_invalid_tool_call(
                        name=chunk["name"],
                        args=raw_args,
                        id=chunk["id"],
                        error=None,
                    )
                )
        self.tool_calls = tool_calls
        self.invalid_tool_calls = invalid_tool_calls
        return self

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate with another AIMessageChunk or a sequence of them.

        Any other operand is delegated to the parent class.
        """
        if isinstance(other, AIMessageChunk):
            return add_ai_message_chunks(self, other)
        elif isinstance(other, (list, tuple)) and all(
            isinstance(o, AIMessageChunk) for o in other
        ):
            return add_ai_message_chunks(self, *other)
        return super().__add__(other)
def add_ai_message_chunks(
    left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
    """Add multiple AIMessageChunks together.

    How each field is combined:

    * ``content`` is merged with ``merge_content``.
    * ``additional_kwargs`` and ``response_metadata`` are merged with
      ``merge_dicts``.
    * Tool call chunks are merged with ``merge_lists`` and re-created through
      ``create_tool_call_chunk``.
    * Token usage is summed.
    * ``example`` and ``id`` are taken from ``left``.

    The input chunks are not modified.

    Args:
        left: The first chunk. Its class is used for the result.
        *others: Further chunks to append, in order.

    Returns:
        A new chunk of the same class as ``left``.

    Raises:
        ValueError: If the chunks have different ``example`` values.
    """
    if any(left.example != o.example for o in others):
        raise ValueError(
            "Cannot concatenate AIMessageChunks with different example values."
        )

    content = merge_content(left.content, *(o.content for o in others))
    additional_kwargs = merge_dicts(
        left.additional_kwargs, *(o.additional_kwargs for o in others)
    )
    response_metadata = merge_dicts(
        left.response_metadata, *(o.response_metadata for o in others)
    )

    # Merge tool call chunks
    if raw_tool_calls := merge_lists(
        left.tool_call_chunks, *(o.tool_call_chunks for o in others)
    ):
        tool_call_chunks = [
            create_tool_call_chunk(
                name=rtc.get("name"),
                args=rtc.get("args"),
                index=rtc.get("index"),
                id=rtc.get("id"),
            )
            for rtc in raw_tool_calls
        ]
    else:
        tool_call_chunks = []

    # Token usage. Totals go into a fresh dict so that neither ``left`` nor
    # any of ``others`` is mutated. Summing in place into
    # ``left.usage_metadata`` would make ``a + b`` change ``a``'s counts.
    usage_metadata: Optional[UsageMetadata] = None
    other_usages = [o.usage_metadata for o in others if o.usage_metadata is not None]
    if left.usage_metadata or other_usages:
        usages = ([left.usage_metadata] if left.usage_metadata else []) + other_usages
        usage_metadata = UsageMetadata(
            input_tokens=sum(u["input_tokens"] for u in usages),
            output_tokens=sum(u["output_tokens"] for u in usages),
            total_tokens=sum(u["total_tokens"] for u in usages),
        )

    return left.__class__(
        example=left.example,
        content=content,
        additional_kwargs=additional_kwargs,
        tool_call_chunks=tool_call_chunks,
        response_metadata=response_metadata,
        usage_metadata=usage_metadata,
        id=left.id,
    )