Source code for langchain_aws.agents.base

from __future__ import annotations

import json
import logging
import time
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

import boto3
from botocore.client import Config
from botocore.exceptions import UnknownServiceError
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import CallbackManager
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_core.tools import BaseTool
from pydantic import model_validator

_DEFAULT_ACTION_GROUP_NAME = "DEFAULT_AG_"
_TEST_AGENT_ALIAS_ID = "TSTALIASID"


def get_boto_session(
    credentials_profile_name: Optional[str] = None,
    region_name: Optional[str] = None,
    endpoint_url: Optional[str] = None,
) -> Any:
    """Build a boto3 session plus the keyword arguments for its clients.

    Args:
        credentials_profile_name: Named profile to create the session with.
            When empty, ``boto3.Session()`` is created without a profile.
        region_name: Added to the client parameters only when truthy.
        endpoint_url: Added to the client parameters only when truthy.

    Returns:
        A ``(client_params, session)`` tuple. ``client_params`` always holds
        a botocore ``Config`` (120 second connect and read timeouts,
        ``max_attempts`` set to 0) under the ``"config"`` key and is meant to
        be splatted into ``session.client(...)``.
    """
    session_kwargs = (
        {"profile_name": credentials_profile_name} if credentials_profile_name else {}
    )
    session = boto3.Session(**session_kwargs)

    client_params = {
        "config": Config(
            connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
        )
    }
    # Only forward the optional settings the caller actually supplied.
    optional_settings = {"region_name": region_name, "endpoint_url": endpoint_url}
    client_params.update(
        {key: value for key, value in optional_settings.items() if value}
    )
    return client_params, session
def parse_agent_response(response: Any) -> OutputType:
    """Parse the raw response of a Bedrock Agent invocation.

    The ``completion`` event stream is consumed until it is exhausted or a
    ``returnControl`` event is seen. ``trace`` events are collected into the
    trace log, and the bytes of all ``chunk`` events are concatenated to form
    the response text.

    Args:
        response: The raw response from Bedrock Agent. Must provide the
            ``"completion"`` event stream and the ``"sessionId"``.

    Returns:
        A one-element list holding a ``BedrockAgentAction`` when the agent
        returned control with at least one invocation input, otherwise a
        ``BedrockAgentFinish``.

    Raises:
        Exception: If the return-control payload cannot be converted into an
            action. The underlying error is chained as ``__cause__``.
    """
    session_id = response["sessionId"]
    trace_log_elements = []
    chunk_bytes = []
    return_control_event = None

    for event in response["completion"]:
        if "trace" in event:
            trace_log_elements.append(event["trace"])
        if "returnControl" in event:
            # The agent hands control back to the caller; nothing after this
            # event is consumed.
            return_control_event = event
            break
        if "chunk" in event:
            chunk_bytes.append(event["chunk"]["bytes"])

    if return_control_event is not None:
        response_text = json.dumps(return_control_event)
    else:
        # Join before decoding so that a multi-byte character split across
        # two chunks still decodes, and so that no chunk is dropped.
        response_text = b"".join(chunk_bytes).decode("utf-8")

    trace_log = json.dumps(trace_log_elements)
    agent_finish = BedrockAgentFinish(
        return_values={"output": response_text},
        log=response_text,
        session_id=session_id,
        trace_log=trace_log,
    )

    # Decide on the event itself rather than on a substring of the text, so a
    # plain answer that merely mentions "returnControl" is not parsed as JSON.
    if return_control_event is None:
        return agent_finish

    return_control = return_control_event.get("returnControl")
    if not return_control:
        return agent_finish

    invocation_inputs = return_control.get("invocationInputs")
    if not invocation_inputs:
        return agent_finish

    try:
        invocation_input = invocation_inputs[0].get("functionInvocationInput", {})
        action_group = invocation_input.get("actionGroup", "")
        function = invocation_input.get("function", "")
        parameters_json = {
            parameter.get("name"): parameter.get("value", "")
            for parameter in invocation_input.get("parameters", [])
        }

        # Tools registered in the default action group are addressed by their
        # bare function name; all others as "<action group>::<function>".
        if _DEFAULT_ACTION_GROUP_NAME in action_group:
            tool = f"{function}"
        else:
            tool = f"{action_group}::{function}"

        return [
            BedrockAgentAction(
                tool=tool,
                tool_input=parameters_json,
                log=response_text,
                session_id=session_id,
                trace_log=trace_log,
            )
        ]
    except Exception as ex:
        raise Exception("Parse exception encountered {}".format(repr(ex))) from ex
def _create_bedrock_agent(
    bedrock_client: Any,
    agent_name: str,
    agent_resource_role_arn: str,
    instruction: str,
    foundation_model: str,
    client_token: Optional[str] = None,
    customer_encryption_key_arn: Optional[str] = None,
    description: Optional[str] = None,
    guardrail_configuration: Optional[GuardrailConfiguration] = None,
    idle_session_ttl_in_seconds: Optional[int] = None,
) -> Union[str, None]:
    """Create the Bedrock agent and wait until it reports ``NOT_PREPARED``.

    Optional settings are only added to the ``create_agent`` request when they
    are provided. After the create call the agent status is polled every two
    seconds for up to ten seconds.

    Args:
        bedrock_client: Client exposing ``create_agent`` and ``get_agent``.
        agent_name: Name of the agent.
        agent_resource_role_arn: ARN of the role passed as
            ``agentResourceRoleArn``.
        instruction: Instruction text for the agent.
        foundation_model: Foundation model identifier.
        client_token: Optional ``clientToken`` for the request.
        customer_encryption_key_arn: Optional ``customerEncryptionKeyArn``.
        description: Optional agent description.
        guardrail_configuration: Optional guardrail settings; an empty
            ``guardrail_version`` falls back to ``"DRAFT"``.
        idle_session_ttl_in_seconds: Optional ``idleSessionTTLInSeconds``.

    Returns:
        The id of the newly created agent.

    Raises:
        Exception: If the agent does not reach ``NOT_PREPARED`` in time.
    """
    create_agent_request: dict = {
        "agentName": agent_name,
        "agentResourceRoleArn": agent_resource_role_arn,
        "foundationModel": foundation_model,
        "instruction": instruction,
    }
    if description:
        create_agent_request["description"] = description
    if client_token:
        create_agent_request["clientToken"] = client_token
    if customer_encryption_key_arn:
        create_agent_request["customerEncryptionKeyArn"] = customer_encryption_key_arn
    if guardrail_configuration is not None:
        create_agent_request["guardrailConfiguration"] = {
            "guardrailIdentifier": guardrail_configuration["guardrail_identifier"],
            "guardrailVersion": guardrail_configuration["guardrail_version"]
            or "DRAFT",
        }
    if idle_session_ttl_in_seconds:
        create_agent_request["idleSessionTTLInSeconds"] = idle_session_ttl_in_seconds

    create_agent_response = bedrock_client.create_agent(**create_agent_request)
    request_id = create_agent_response.get("ResponseMetadata", {}).get("RequestId", "")
    logging.info(f"Create bedrock agent call successful with request id: {request_id}")
    agent_id = create_agent_response["agent"]["agentId"]

    create_agent_start_time = time.time()
    while time.time() - create_agent_start_time < 10:
        agent_creation_status = (
            bedrock_client.get_agent(agentId=agent_id).get("agent", {}).get("agentStatus")
        )
        if agent_creation_status == "NOT_PREPARED":
            return agent_id
        time.sleep(2)

    logging.error(f"Failed to create bedrock agent {agent_id}")
    raise Exception(f"Failed to create bedrock agent {agent_id}")


def _get_action_group_and_function_names(tool: BaseTool) -> Tuple[str, str]:
    """Convert the LangChain 'Tool' into Bedrock Action Group and Function names.

    A tool named ``"<group>::<function>"`` maps to that group and function; a
    name without ``"::"`` maps to the default action group.

    Args:
        tool: Object whose ``name`` attribute is inspected.

    Returns:
        ``(action_group_name, function_name)``.
    """
    tool_name_split = tool.name.split("::")
    if len(tool_name_split) > 1:
        return tool_name_split[0], tool_name_split[1]
    return _DEFAULT_ACTION_GROUP_NAME, tool.name


def _create_bedrock_action_groups(
    bedrock_client: Any, agent_id: str, tools: List[BaseTool]
) -> None:
    """Create the bedrock action groups for the agent.

    Tools are grouped by action group name and one
    ``create_agent_action_group`` call is issued per group against the
    ``DRAFT`` agent version, using the ``RETURN_CONTROL`` executor.

    Args:
        bedrock_client: Client exposing ``create_agent_action_group``.
        agent_id: Id of the agent that receives the action groups.
        tools: Tools to register as functions.
    """
    tools_by_action_group = defaultdict(list)
    for tool in tools:
        action_group_name, _ = _get_action_group_and_function_names(tool)
        tools_by_action_group[action_group_name].append(tool)

    for action_group_name, functions in tools_by_action_group.items():
        bedrock_client.create_agent_action_group(
            actionGroupName=action_group_name,
            actionGroupState="ENABLED",
            actionGroupExecutor={"customControl": "RETURN_CONTROL"},
            functionSchema={
                "functions": [_tool_to_function(function) for function in functions]
            },
            agentId=agent_id,
            agentVersion="DRAFT",
        )


def _tool_to_function(tool: BaseTool) -> dict:
    """Convert LangChain tool to a Bedrock function schema.

    Args:
        tool: Object providing ``name``, ``description`` and an ``args``
            mapping of argument name to its schema details.

    Returns:
        A dict with ``description``, ``name`` and ``parameters`` keys. Each
        parameter carries a description (falling back to the schema title,
        then to the argument name), a type (default ``"string"``) and a
        ``required`` flag.
    """
    _, function_name = _get_action_group_and_function_names(tool)
    function_parameters = {
        arg_name: {
            "description": arg_details.get(
                "description", arg_details.get("title", arg_name)
            ),
            "type": arg_details.get("type", "string"),
            # An argument is optional as soon as a default is declared, even
            # when that default is falsy (0, "", False or None).
            "required": "default" not in arg_details,
        }
        for arg_name, arg_details in tool.args.items()
    }
    return {
        "description": tool.description,
        "name": function_name,
        "parameters": function_parameters,
    }


def _prepare_agent(bedrock_client: Any, agent_id: str) -> None:
    """Prepare the agent for invocations.

    Calls ``prepare_agent`` and then polls ``get_agent`` every two seconds for
    up to ten seconds until the status is ``PREPARED``.

    Args:
        bedrock_client: Client exposing ``prepare_agent`` and ``get_agent``.
        agent_id: Id of the agent to prepare.

    Raises:
        Exception: If the agent is not ``PREPARED`` before the timeout.
    """
    bedrock_client.prepare_agent(agentId=agent_id)
    prepare_agent_start_time = time.time()
    while time.time() - prepare_agent_start_time < 10:
        agent_status = bedrock_client.get_agent(agentId=agent_id)
        if agent_status.get("agent", {}).get("agentStatus", "") == "PREPARED":
            return
        time.sleep(2)
    raise Exception(f"Timed out while preparing the agent with id {agent_id}")


def _get_bedrock_agent(bedrock_client: Any, agent_name: str) -> Any:
    """Get the agent by name.

    Pages through ``list_agents`` (1000 results per page) and returns the
    first summary whose ``agentName`` equals ``agent_name``.

    Args:
        bedrock_client: Client exposing ``list_agents``.
        agent_name: Exact agent name to look for.

    Returns:
        The matching agent summary, or ``None`` when no page contains it.
    """
    request: dict = {"maxResults": 1000}
    while True:
        list_agents_response = bedrock_client.list_agents(**request)
        agent_summaries = list_agents_response.get("agentSummaries", [])
        agent_summary = next(
            (x for x in agent_summaries if x.get("agentName") == agent_name), None
        )
        if agent_summary:
            return agent_summary

        next_token = list_agents_response.get("nextToken")
        # Any falsy token (missing or empty) ends the pagination; re-issuing
        # the first-page request on an empty token would never terminate.
        if not next_token:
            return None
        request["nextToken"] = next_token
class BedrockAgentFinish(AgentFinish):
    """AgentFinish with session id information.

    Parameters:
        session_id: Session id
        trace_log: trace log as string when enable_trace flag is set;
            defaults to ``None`` so the field may be omitted
    """

    session_id: str
    trace_log: Optional[str] = None

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Check if the class is serializable by LangChain.

        Returns:
            False
        """
        return False
class BedrockAgentAction(AgentAction):
    """AgentAction with session id information.

    Parameters:
        session_id: session id
        trace_log: trace log as string when enable_trace flag is set;
            defaults to ``None`` so the field may be omitted
    """

    session_id: str
    trace_log: Optional[str] = None

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Check if the class is serializable by LangChain.

        Returns:
            False
        """
        return False
# Output type of the runnable: either a list of BedrockAgentAction objects or
# a single BedrockAgentFinish.
OutputType = Union[List[BedrockAgentAction], BedrockAgentFinish]
class GuardrailConfiguration(TypedDict):
    """Guardrail settings accepted when creating a Bedrock agent."""

    # Sent to Bedrock as "guardrailIdentifier".
    guardrail_identifier: str
    # Sent to Bedrock as "guardrailVersion"; an empty value is replaced by
    # "DRAFT" when the create-agent request is built.
    guardrail_version: str
class BedrockAgentsRunnable(RunnableSerializable[Dict, OutputType]):
    """
    Invoke a Bedrock Agent
    """

    agent_id: Optional[str]
    """Bedrock Agent Id"""
    agent_alias_id: Optional[str] = _TEST_AGENT_ALIAS_ID
    """Bedrock Agent Alias Id"""
    client: Any
    """Boto3 client"""
    region_name: Optional[str] = None
    """Region"""
    credentials_profile_name: Optional[str] = None
    """Credentials to use to invoke the agent"""
    endpoint_url: Optional[str] = None
    """Endpoint URL"""
    enable_trace: Optional[bool] = False
    """Boolean flag to enable trace when invoking Bedrock Agent"""

    @model_validator(mode="before")
    @classmethod
    def validate_agent(cls, values: dict) -> Any:
        """Ensure a ``bedrock-agent-runtime`` client is available.

        When no client is supplied, one is created from the optional
        ``credentials_profile_name``, ``region_name`` and ``endpoint_url``
        values.

        Args:
            values: Raw constructor values (field defaults are not applied
                yet, so optional keys may be absent).

        Returns:
            The values, with ``client`` populated.

        Raises:
            ModuleNotFoundError: If boto3 is missing or does not know the
                ``bedrock-agent-runtime`` service.
            ValueError: If the client cannot be created for another reason.
        """
        if values.get("client") is not None:
            return values

        try:
            # Use .get(): this validator runs before defaults are filled in,
            # so omitted optional fields are simply not present in `values`.
            client_params, session = get_boto_session(
                credentials_profile_name=values.get("credentials_profile_name"),
                region_name=values.get("region_name"),
                endpoint_url=values.get("endpoint_url"),
            )

            values["client"] = session.client("bedrock-agent-runtime", **client_params)

            return values
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except UnknownServiceError as e:
            raise ModuleNotFoundError(
                "Ensure that you have installed the latest boto3 package "
                "that contains the API for `bedrock-agent-runtime`."
            ) from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    @classmethod
    def create_agent(
        cls,
        agent_name: str,
        agent_resource_role_arn: str,
        foundation_model: str,
        instruction: str,
        tools: Optional[List[BaseTool]] = None,
        *,
        client_token: Optional[str] = None,
        customer_encryption_key_arn: Optional[str] = None,
        description: Optional[str] = None,
        guardrail_configuration: Optional[GuardrailConfiguration] = None,
        idle_session_ttl_in_seconds: Optional[int] = None,
        credentials_profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        bedrock_endpoint_url: Optional[str] = None,
        runtime_endpoint_url: Optional[str] = None,
        enable_trace: Optional[bool] = False,
        **kwargs: Any,
    ) -> BedrockAgentsRunnable:
        """
        Creates a Bedrock Agent Runnable that can be used with an AgentExecutor
        or with LangGraph.
        This also sets up the Bedrock agent, actions and action groups infrastructure
        if they don't exist,
        ensures the agent is in PREPARED state so that it is ready to be called.

        Args:
            agent_name: Name of the agent
            agent_resource_role_arn: The Amazon Resource Name (ARN) of the IAM role
                with permissions to invoke API operations on the agent.
            foundation_model: The foundation model to be used for orchestration
                by the agent you create
            instruction: Instructions that tell the agent what it should do and how
                it should interact with users
            tools: List of tools. Accepts LangChain's BaseTool format. ``None``
                is treated as an empty list
            client_token: A unique, case-sensitive identifier to ensure that the API
                request completes no more than one time. If this token matches
                a previous request, Amazon Bedrock ignores the request,
                but does not return an error
            customer_encryption_key_arn: The Amazon Resource Name (ARN) of the
                KMS key with which to encrypt the agent
            description: A description of the agent
            guardrail_configuration: The unique Guardrail configuration assigned
                to the agent when it is created.
            idle_session_ttl_in_seconds: The number of seconds for which Amazon Bedrock
                keeps information about a user's conversation with the agent.
                A user interaction remains active for the amount of time specified.
                If no conversation occurs during this time, the session expires and
                Amazon Bedrock deletes any data provided before the timeout
            credentials_profile_name: The profile name to use if different from default
            region_name: Region for the Bedrock agent
            bedrock_endpoint_url: Endpoint URL for bedrock agent
            runtime_endpoint_url: Endpoint URL for bedrock agent runtime
            enable_trace: Boolean flag to specify whether trace should be enabled when
                invoking the agent
            **kwargs: Additional arguments
        Returns:
            BedrockAgentsRunnable configured to invoke the Bedrock agent
        """
        client_params, session = get_boto_session(
            credentials_profile_name=credentials_profile_name,
            region_name=region_name,
            endpoint_url=bedrock_endpoint_url,
        )
        bedrock_client = session.client("bedrock-agent", **client_params)
        bedrock_agent = _get_bedrock_agent(
            bedrock_client=bedrock_client, agent_name=agent_name
        )

        if bedrock_agent:
            # Reuse the existing agent; only prepare it when necessary.
            agent_id = bedrock_agent["agentId"]
            agent_status = bedrock_agent["agentStatus"]
            if agent_status != "PREPARED":
                _prepare_agent(bedrock_client, agent_id)
        else:
            try:
                agent_id = _create_bedrock_agent(
                    bedrock_client=bedrock_client,
                    agent_name=agent_name,
                    agent_resource_role_arn=agent_resource_role_arn,
                    instruction=instruction,
                    foundation_model=foundation_model,
                    client_token=client_token,
                    customer_encryption_key_arn=customer_encryption_key_arn,
                    description=description,
                    guardrail_configuration=guardrail_configuration,
                    idle_session_ttl_in_seconds=idle_session_ttl_in_seconds,
                )
                _create_bedrock_action_groups(bedrock_client, agent_id, tools or [])
                _prepare_agent(bedrock_client, agent_id)
            except Exception as exception:
                logging.error(f"Error in create agent call: {exception}")
                raise

        return cls(
            agent_id=agent_id,
            region_name=region_name,
            credentials_profile_name=credentials_profile_name,
            endpoint_url=runtime_endpoint_url,
            enable_trace=enable_trace,
            **kwargs,
        )

    def invoke(
        self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> OutputType:
        """
        Invoke the Bedrock agent.

        Args:
            input: The LangChain Runnable input dictionary that can include:
                input: The input text to the agent
                memory_id: The memory id to use for an agent with memory enabled
                session_id: The session id to use. If not provided, a new
                    session will be started
                end_session: Boolean indicating whether to end a session or not
                intermediate_steps: The intermediate steps that are used to provide RoC
                    invocation details
            config: The optional RunnableConfig

        Returns:
            Union[List[BedrockAgentAction], BedrockAgentFinish]
        """
        config = ensure_config(config)
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            inheritable_tags=config.get("tags"),
            inheritable_metadata=config.get("metadata"),
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input, name=config.get("run_name")
        )

        try:
            agent_input: Dict[str, Any] = {
                "agentId": self.agent_id,
                "agentAliasId": self.agent_alias_id,
                "enableTrace": self.enable_trace,
                "endSession": bool(input.get("end_session", False)),
            }

            if input.get("memory_id"):
                agent_input["memoryId"] = input.get("memory_id")

            if input.get("intermediate_steps"):
                # Continue the session of the last step and hand the tool
                # result back to the agent instead of sending new input text.
                session_id, session_state = self._parse_intermediate_steps(
                    input.get("intermediate_steps")  # type: ignore[arg-type]
                )

                if session_id is not None:
                    agent_input["sessionId"] = session_id

                if session_state is not None:
                    agent_input["sessionState"] = session_state
            else:
                agent_input["inputText"] = input.get("input", "")
                # `or` also covers an explicit session_id=None / "", which
                # would otherwise be sent to the service as-is.
                agent_input["sessionId"] = input.get("session_id") or str(uuid.uuid4())

            output = self.client.invoke_agent(**agent_input)
            response = parse_agent_response(output)
        except Exception as e:
            run_manager.on_chain_error(e)
            raise

        run_manager.on_chain_end(response)
        return response

    def _parse_intermediate_steps(
        self, intermediate_steps: List[Tuple[BedrockAgentAction, str]]
    ) -> Tuple[Union[str, None], Union[Dict[str, Any], None]]:
        """Build the session id and session state from the last step.

        Only the last ``(action, tool_output)`` pair is considered. The
        invocation id is read from the JSON content of the action's last
        message, and the tool output is returned to the agent as the function
        result.

        Args:
            intermediate_steps: ``(action, tool_output)`` pairs.

        Returns:
            ``(session_id, session_state)``, or ``(None, None)`` when there is
            no step, no invoked tool, or the last message is not an
            ``AIMessage``.
        """
        if not intermediate_steps:
            return None, None

        last_step = intermediate_steps[-1]
        action = last_step[0]
        tool_invoked = action.tool
        messages = action.messages
        session_id = action.session_id

        # Without a tool name there is no function result to report back.
        if not tool_invoked or not messages:
            return None, None

        message = messages[-1]
        if not isinstance(message, AIMessage):
            return None, None

        tool_name_split = tool_invoked.split("::")
        if len(tool_name_split) > 1:
            action_group_name = tool_name_split[0]
            function_name = tool_name_split[1]
        else:
            action_group_name = _DEFAULT_ACTION_GROUP_NAME
            function_name = tool_invoked

        response = last_step[1]
        session_state = {
            "invocationId": json.loads(message.content)  # type: ignore[arg-type]
            .get("returnControl", {})
            .get("invocationId", ""),
            "returnControlInvocationResults": [
                {
                    "functionResult": {
                        "actionGroup": action_group_name,
                        "function": function_name,
                        "responseBody": {"TEXT": {"body": response}},
                    }
                }
            ],
        }

        return session_id, session_state