import asyncio
import json
import logging
import warnings
from abc import ABC
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM, BaseLanguageModel, LangSmithParams
from langchain_core.messages import AIMessageChunk, ToolCall
from langchain_core.messages.tool import tool_call, tool_call_chunk
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_aws.function_calling import _tools_in_params
from langchain_aws.utils import (
    anthropic_tokens_supported,
    create_aws_client,
    enforce_stop_tokens,
    get_num_tokens_anthropic,
    get_token_ids_anthropic,
    thinking_in_params,
)

logger = logging.getLogger(__name__)

AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAction"
HUMAN_PROMPT = "\n\nHuman:"
ASSISTANT_PROMPT = "\n\nAssistant:"
ALTERNATION_ERROR = (
    "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
)


def _add_newlines_before_ha(input_text: str) -> str:
    new_text = input_text
    for word in ["Human:", "Assistant:"]:
        new_text = new_text.replace(word, "\n\n" + word)
        for i in range(2):
            new_text = new_text.replace("\n\n\n" + word, "\n\n" + word)
    return new_text


def _human_assistant_format(input_text: str) -> str:
    if input_text.count("Human:") == 0 or (
        input_text.find("Human:") > input_text.find("Assistant:")
        and "Assistant:" in input_text
    ):
        input_text = HUMAN_PROMPT + " " + input_text  # SILENT CORRECTION
    if input_text.count("Assistant:") == 0:
        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
    if input_text[: len("Human:")] == "Human:":
        input_text = "\n\n" + input_text
    input_text = _add_newlines_before_ha(input_text)
    count = 0
    # track alternation
    for i in range(len(input_text)):
        if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT:
            if count % 2 == 0:
                count += 1
            else:
                warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")
        if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT:
            if count % 2 == 1:
                count += 1
            else:
                warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")

    if count % 2 == 1:  # Only saw Human, no Assistant
        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION

    return input_text


def _stream_response_to_generation_chunk(
    stream_response: Dict[str, Any],
    provider: str,
    output_key: str,
    messages_api: bool,
    coerce_content_to_string: bool,
) -> Union[GenerationChunk, AIMessageChunk, None]:  # type: ignore[return]
    """Convert a stream response to a generation chunk."""
    if messages_api:
        msg_type = stream_response.get("type")
        if msg_type == "message_start":
            return AIMessageChunk(
                content="" if coerce_content_to_string else [],
            )
        elif (
            msg_type == "content_block_start"
            and stream_response["content_block"] is not None
            and stream_response["content_block"]["type"] == "tool_use"
        ):
            content_block = stream_response["content_block"]
            content_block["index"] = stream_response["index"]
            tc_chunk = tool_call_chunk(
                index=stream_response["index"],
                id=stream_response["content_block"]["id"],
                name=stream_response["content_block"]["name"],
                args="",
            )
            return AIMessageChunk(
                content=[content_block],
                tool_call_chunks=[tc_chunk],  # type: ignore
            )
        elif msg_type == "content_block_delta":
            if not stream_response["delta"]:
                return AIMessageChunk(content="")
            if stream_response["delta"]["type"] == "text_delta":
                if coerce_content_to_string:
                    return AIMessageChunk(content=stream_response["delta"]["text"])
                else:
                    content_block = stream_response["delta"]
                    content_block["index"] = stream_response["index"]
                    content_block["type"] = "text"
                    return AIMessageChunk(content=[content_block])
            elif stream_response["delta"]["type"] == "input_json_delta":
                content_block = stream_response["delta"]
                content_block["index"] = stream_response["index"]
                content_block["type"] = "tool_use"
                tc_chunk = {
                    "index": stream_response["index"],
                    "id": None,
                    "name": None,
                    "args": stream_response["delta"]["partial_json"],
                }
                return AIMessageChunk(
                    content=[content_block],
                    tool_call_chunks=[tc_chunk],  # type: ignore
                )
            elif stream_response["delta"]["type"] == "thinking_delta":
                content_block = stream_response["delta"]
                content_block["index"] = stream_response["index"]
                content_block["type"] = "thinking"
                return AIMessageChunk(content=[content_block])
            elif stream_response["delta"]["type"] == "signature_delta":
                content_block = stream_response["delta"]
                content_block["index"] = stream_response["index"]
                content_block["type"] = "thinking"
                return AIMessageChunk(content=[content_block])
        elif msg_type == "message_delta":
            return AIMessageChunk(
                content="",
                response_metadata={
                    "stop_reason": stream_response["delta"].get("stop_reason"),
                    "stop_sequence": stream_response["delta"].get("stop_sequence"),
                },
            )
        else:
            return None

    # chunk obj format varies with provider
    generation_info = {
        k: v
        for k, v in stream_response.items()
        if k
        not in [output_key, "prompt_token_count", "generation_token_count", "created"]
    }
    return GenerationChunk(
        text=(
            stream_response[output_key]
            if provider not in ["mistral", "deepseek"]
            else stream_response[output_key][0]["text"]
        ),
        generation_info=generation_info,
    )


def _combine_generation_info_for_llm_result(
    chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str
) -> Dict[str, Any]:
    """
    Returns usage and stop reason information with the intent to pack into an
    LLMResult.

    Takes a list of generation_info from GenerationChunks.
    If the messages api is being used, the generation_info from some of these
    chunks should contain "usage" keys; if not, the token counts should be found
    within "amazon-bedrock-invocationMetrics".
    """
    total_usage_info = {"prompt_tokens": 0, "completion_tokens": 0}
    stop_reason = ""
    for generation_info in chunks_generation_info:
        if "usage" in generation_info:
            usage_info = generation_info["usage"]
            if "input_tokens" in usage_info:
                total_usage_info["prompt_tokens"] += sum(usage_info["input_tokens"])
            if "output_tokens" in usage_info:
                total_usage_info["completion_tokens"] += sum(
                    usage_info["output_tokens"]
                )
        if "amazon-bedrock-invocationMetrics" in generation_info:
            usage_info = generation_info["amazon-bedrock-invocationMetrics"]
            if "inputTokenCount" in usage_info:
                total_usage_info["prompt_tokens"] += usage_info["inputTokenCount"]
            if "outputTokenCount" in usage_info:
                total_usage_info["completion_tokens"] += usage_info["outputTokenCount"]

        if provider_stop_code is not None and provider_stop_code in generation_info:
            # uses the last stop reason
            stop_reason = generation_info[provider_stop_code]

    total_usage_info["total_tokens"] = (
        total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"]
    )

    return {"usage": total_usage_info, "stop_reason": stop_reason}


def _get_invocation_metrics_chunk(chunk: Dict[str, Any]) -> GenerationChunk:
    generation_info = {}
    if metrics := chunk.get("amazon-bedrock-invocationMetrics"):
        input_tokens = metrics.get("inputTokenCount", 0)
        output_tokens = metrics.get("outputTokenCount", 0)
        generation_info["usage_metadata"] = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        }
    return GenerationChunk(text="", generation_info=generation_info)
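# Illustrative sketch (not part of the library): how `_human_assistant_format`
# normalizes a raw prompt for the legacy Anthropic text-completion API. The
# resulting string shown in the comment is derived from the logic above.
#
#     prompt = _human_assistant_format("Tell me a joke.")
#     # prompt == "\n\nHuman: Tell me a joke.\n\nAssistant:"
#
# A prompt that already alternates "\n\nHuman:" / "\n\nAssistant:" turns is left
# structurally intact; a prompt missing the trailing "Assistant:" turn has it
# appended silently (see the SILENT CORRECTION comments above).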
class LLMInputOutputAdapter:
    """Adapter class to prepare the inputs from Langchain to a format
    that LLM model expects.

    It also provides helper function to extract
    the generated text from the model response."""

    provider_to_output_key_map = {
        "anthropic": "completion",
        "amazon": "outputText",
        "cohere": "text",
        "deepseek": "choices",
        "meta": "generation",
        "mistral": "outputs",
    }
    @classmethod
    def prepare_input(
        cls,
        provider: str,
        model_kwargs: Dict[str, Any],
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        tools: Optional[List[AnthropicTool]] = None,
        *,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> Dict[str, Any]:
        input_body = {**model_kwargs}
        if provider == "anthropic":
            if messages:
                # Check if we're using extended thinking
                thinking_enabled = thinking_in_params(model_kwargs)

                if tools:
                    input_body["tools"] = tools
                input_body["anthropic_version"] = "bedrock-2023-05-31"

                # Special handling for tool results with thinking
                if thinking_enabled:
                    # Check if we have a tool_result in the last user message
                    # and need to ensure the previous assistant message starts
                    # with thinking
                    if (
                        len(messages) >= 2
                        and messages[-1]["role"] == "user"
                        and messages[-2]["role"] == "assistant"
                    ):
                        # Check if the last user message contains tool_result
                        last_user_msg = messages[-1].get("content", [])
                        tool_result = False
                        if isinstance(last_user_msg, list):
                            tool_result = any(
                                item.get("type") == "tool_result"
                                for item in last_user_msg
                                if isinstance(item, dict)
                            )

                        if tool_result:
                            # Make sure the assistant message has thinking first
                            asst_content = messages[-2].get("content", [])
                            if isinstance(asst_content, list) and asst_content:
                                # Find thinking blocks and move them to the front
                                # if needed
                                thinking_blocks = [
                                    block
                                    for block in asst_content
                                    if isinstance(block, dict)
                                    and block.get("type")
                                    in ["thinking", "redacted_thinking"]
                                ]
                                if thinking_blocks and asst_content[0].get(
                                    "type"
                                ) not in ["thinking", "redacted_thinking"]:
                                    # Reorder to put thinking blocks first
                                    new_content = thinking_blocks.copy()
                                    new_content.extend(
                                        [
                                            block
                                            for block in asst_content
                                            if isinstance(block, dict)
                                            and block.get("type")
                                            not in ["thinking", "redacted_thinking"]
                                        ]
                                    )
                                    messages[-2]["content"] = new_content

                input_body["messages"] = messages
                if system:
                    input_body["system"] = system
                if max_tokens:
                    input_body["max_tokens"] = max_tokens
                elif "max_tokens" not in input_body:
                    input_body["max_tokens"] = 1024

            if prompt:
                input_body["prompt"] = _human_assistant_format(prompt)
                if max_tokens:
                    input_body["max_tokens_to_sample"] = max_tokens
                elif "max_tokens_to_sample" not in input_body:
                    input_body["max_tokens_to_sample"] = 1024

            if temperature is not None:
                input_body["temperature"] = temperature

        elif provider in ("ai21", "cohere", "meta", "mistral", "deepseek"):
            input_body["prompt"] = prompt
            if max_tokens:
                if provider == "cohere":
                    input_body["max_tokens"] = max_tokens
                elif provider == "meta":
                    input_body["max_gen_len"] = max_tokens
                elif provider == "mistral":
                    input_body["max_tokens"] = max_tokens
                elif provider == "deepseek":
                    input_body["max_tokens"] = max_tokens
                else:
                    # TODO: Add AI21 support, param depends on specific model.
                    pass
            if temperature is not None:
                input_body["temperature"] = temperature

        elif provider == "amazon":
            input_body = dict()
            input_body["inputText"] = prompt
            input_body["textGenerationConfig"] = {**model_kwargs}
            if max_tokens:
                input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens
            if temperature is not None:
                input_body["textGenerationConfig"]["temperature"] = temperature
        else:
            input_body["inputText"] = prompt

        return input_body
    @classmethod
    def prepare_output(cls, provider: str, response: Any) -> dict:
        text = ""
        tool_calls = []
        thinking = {}
        response_body = json.loads(response.get("body").read().decode())

        if provider == "anthropic":
            if "completion" in response_body:
                text = response_body.get("completion")
            elif "content" in response_body:
                content = response_body.get("content", [])

                # Extract text content
                text_blocks = [
                    block["text"] for block in content if block.get("type") == "text"
                ]
                if text_blocks:
                    text = "".join(text_blocks)

                # Extract thinking content
                thinking_blocks = [
                    block for block in content if block.get("type") == "thinking"
                ]
                if thinking_blocks:
                    # Get the first thinking block (there's typically just one)
                    thinking_block = thinking_blocks[0]
                    thinking = {
                        "text": thinking_block.get("thinking", ""),
                        "signature": thinking_block.get("signature", ""),
                    }

                # Extract tool calls if present
                if any(block.get("type") == "tool_use" for block in content):
                    tool_calls = extract_tool_calls(content)
        else:
            if provider == "ai21":
                text = response_body.get("completions")[0].get("data").get("text")
            elif provider == "cohere":
                text = response_body.get("generations")[0].get("text")
            elif provider == "meta":
                text = response_body.get("generation")
            elif provider == "mistral":
                text = response_body.get("outputs")[0].get("text")
            elif provider == "deepseek":
                text = response_body.get("choices")[0].get("text")
            else:
                text = response_body.get("results")[0].get("outputText")

        headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
        prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
        completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
        return {
            "text": text,
            "thinking": thinking,
            "tool_calls": tool_calls,
            "body": response_body,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "stop_reason": response_body.get("stop_reason"),
        }
    @classmethod
    def prepare_output_stream(
        cls,
        provider: str,
        response: Any,
        stop: Optional[List[str]] = None,
        messages_api: bool = False,
        coerce_content_to_string: bool = False,
    ) -> Iterator[Union[GenerationChunk, AIMessageChunk]]:
        stream = response.get("body")

        if not stream:
            return

        if messages_api:
            output_key = "message"
        else:
            output_key = cls.provider_to_output_key_map.get(provider, "")

        if not output_key:
            raise ValueError(
                f"Unknown streaming response output key for provider: {provider}"
            )

        for event in stream:
            chunk = event.get("chunk")
            if not chunk:
                continue

            chunk_obj = json.loads(chunk.get("bytes").decode())

            if provider == "cohere" and (
                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
            ):
                return

            generation_chunk = _stream_response_to_generation_chunk(
                chunk_obj,
                provider=provider,
                output_key=output_key,
                messages_api=messages_api,
                coerce_content_to_string=coerce_content_to_string,
            )
            if generation_chunk:
                yield generation_chunk

            if provider == "deepseek":
                opt = chunk_obj.get(output_key, [{}])[0]
                if (
                    opt.get("stop_reason") in ["stop", "length"]
                    or opt.get("finish_reason") == "eos_token"
                ):
                    yield _get_invocation_metrics_chunk(chunk_obj)
                    return
            elif (
                provider == "mistral"
                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
            ):
                yield _get_invocation_metrics_chunk(chunk_obj)
                return
            elif provider == "meta" and chunk_obj.get("stop_reason", "") == "stop":
                yield _get_invocation_metrics_chunk(chunk_obj)
                return
            elif messages_api and (chunk_obj.get("type") == "message_stop"):
                yield _get_invocation_metrics_chunk(chunk_obj)
                return
    @classmethod
    async def aprepare_output_stream(
        cls,
        provider: str,
        response: Any,
        stop: Optional[List[str]] = None,
        messages_api: bool = False,
        coerce_content_to_string: bool = False,
    ) -> AsyncIterator[Union[GenerationChunk, AIMessageChunk]]:
        stream = response.get("body")

        if not stream:
            return

        output_key = cls.provider_to_output_key_map.get(provider, None)

        if not output_key:
            raise ValueError(
                f"Unknown streaming response output key for provider: {provider}"
            )

        for event in stream:
            chunk = event.get("chunk")
            if not chunk:
                continue

            chunk_obj = json.loads(chunk.get("bytes").decode())

            if provider == "cohere" and (
                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
            ):
                return

            if (
                provider == "mistral"
                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
            ):
                return

            generation_chunk = _stream_response_to_generation_chunk(
                chunk_obj,
                provider=provider,
                output_key=output_key,
                messages_api=messages_api,
                coerce_content_to_string=coerce_content_to_string,
            )
            if generation_chunk:
                yield generation_chunk
            else:
                continue
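# Illustrative sketch (not part of the library): shaping an InvokeModel request
# body with `LLMInputOutputAdapter.prepare_input`. The prompt and parameters are
# made-up placeholders; the adapter only builds the JSON body, the HTTP call is
# made by the boto3 "bedrock-runtime" client elsewhere in this module.
#
#     body = LLMInputOutputAdapter.prepare_input(
#         provider="anthropic",
#         model_kwargs={},
#         prompt="Tell me a joke.",
#         max_tokens=256,
#         temperature=0.5,
#     )
#     # body == {
#     #     "prompt": "\n\nHuman: Tell me a joke.\n\nAssistant:",
#     #     "max_tokens_to_sample": 256,
#     #     "temperature": 0.5,
#     # }
#
# For the Claude messages API, `messages=` and `system=` are passed instead of
# `prompt=`, and the body gains "anthropic_version" and "max_tokens" keys, as
# implemented in `prepare_input` above.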
class BedrockBase(BaseLanguageModel, ABC):
    """Base class for Bedrock models."""

    client: Any = Field(default=None, exclude=True)  #: :meta private:

    region_name: Optional[str] = Field(default=None, alias="region")
    """The aws region e.g., `us-west-2`. Falls back to AWS_REGION or
    AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
    in case it is not provided here.
    """

    credentials_profile_name: Optional[str] = Field(default=None, exclude=True)
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files,
    which has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    aws_access_key_id: Optional[SecretStr] = Field(
        default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None)
    )
    """AWS access key id.

    If provided, aws_secret_access_key must also be provided.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If not provided, will be read from 'AWS_ACCESS_KEY_ID' environment variable.
    """

    aws_secret_access_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None)
    )
    """AWS secret_access_key.

    If provided, aws_access_key_id must also be provided.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If not provided, will be read from 'AWS_SECRET_ACCESS_KEY' environment variable.
    """

    aws_session_token: Optional[SecretStr] = Field(
        default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None)
    )
    """AWS session token.

    If provided, aws_access_key_id and aws_secret_access_key must
    also be provided. Not required unless using temporary credentials.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If not provided, will be read from 'AWS_SESSION_TOKEN' environment variable.
    """

    config: Any = None
    """An optional botocore.config.Config instance to pass to the client."""

    provider: Optional[str] = None
    """The model provider, e.g., amazon, cohere, ai21, etc. When not supplied,
    provider is extracted from the first part of the model_id, e.g. 'amazon' in
    'amazon.titan-text-express-v1'. This value should be provided for model ids
    that do not have the provider in them, e.g., custom and provisioned models
    that have an ARN associated with them."""

    model_id: str = Field(alias="model")
    """Id of the model to call, e.g., amazon.titan-text-express-v1; this is
    equivalent to the modelId property in the list-foundation-models api.
    For custom and provisioned models, an ARN value is expected."""

    model_kwargs: Optional[Dict[str, Any]] = None
    """Keyword arguments to pass to the model."""

    endpoint_url: Optional[str] = None
    """Needed if you don't want to default to us-east-1 endpoint"""

    streaming: bool = False
    """Whether to stream the results."""

    provider_stop_sequence_key_name_map: Mapping[str, str] = {
        "anthropic": "stop_sequences",
        "amazon": "stopSequences",
        "ai21": "stop_sequences",
        "cohere": "stop_sequences",
        "mistral": "stop_sequences",
    }

    provider_stop_reason_key_map: Mapping[str, str] = {
        "anthropic": "stop_reason",
        "amazon": "completionReason",
        "ai21": "finishReason",
        "cohere": "finish_reason",
        "mistral": "stop_reason",
    }

    guardrails: Optional[Mapping[str, Any]] = {
        "trace": None,
        "guardrailIdentifier": None,
        "guardrailVersion": None,
    }
    """
    An optional dictionary to configure guardrails for Bedrock.

    This field 'guardrails' consists of two keys: 'guardrailIdentifier' and
    'guardrailVersion', which should be strings, but are initialized to None.
    It's used to determine if specific guardrails are enabled and properly set.

    Type:
        Optional[Mapping[str, str]]: A mapping with 'guardrailIdentifier' and
        'guardrailVersion' keys.

    Example:
    llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
                     model_kwargs={},
                     guardrails={
                        "guardrailIdentifier": "<guardrail_id>",
                        "guardrailVersion": "<guardrail_version>"})

    To enable tracing for guardrails, set the 'trace' key to True
    and pass a callback handler to the 'run_manager' parameter of the
    'generate', '_call' methods.

    Example:
    llm = BedrockLLM(model_id="<model_id>", client=<bedrock_client>,
                     model_kwargs={},
                     guardrails={
                        "guardrailIdentifier": "<guardrail_id>",
                        "guardrailVersion": "<guardrail_version>",
                        "trace": True},
                     callbacks=[BedrockAsyncCallbackHandler()])

    See [https://python.langchain.com/docs/modules/callbacks/] for more information
    on callback handlers.

    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
        async def on_llm_error(
            self,
            error: BaseException,
            **kwargs: Any,
        ) -> Any:
            reason = kwargs.get("reason")
            if reason == "GUARDRAIL_INTERVENED":
                ...Logic to handle guardrail intervention...
    """  # noqa: E501

    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {
            "aws_access_key_id": "AWS_ACCESS_KEY_ID",
            "aws_secret_access_key": "AWS_SECRET_ACCESS_KEY",
            "aws_session_token": "AWS_SESSION_TOKEN",
        }

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that AWS credentials and the required Python packages exist
        in the environment."""

        if self.model_kwargs:
            if "temperature" in self.model_kwargs:
                if self.temperature is None:
                    self.temperature = self.model_kwargs["temperature"]
                self.model_kwargs.pop("temperature")

            if "max_tokens" in self.model_kwargs:
                if not self.max_tokens:
                    self.max_tokens = self.model_kwargs["max_tokens"]
                self.model_kwargs.pop("max_tokens")

        # Skip creating new client if passed in constructor
        if self.client is None:
            self.client = create_aws_client(
                region_name=self.region_name,
                credentials_profile_name=self.credentials_profile_name,
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
                aws_session_token=self.aws_session_token,
                endpoint_url=self.endpoint_url,
                config=self.config,
                service_name="bedrock-runtime",
            )

        return self

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        _model_kwargs = self.model_kwargs or {}
        return {
            "model_id": self.model_id,
            "provider": self._get_provider(),
            "stream": self.streaming,
            "trace": self.guardrails.get("trace"),  # type: ignore[union-attr]
            "guardrailIdentifier": self.guardrails.get(  # type: ignore[union-attr]
                "guardrailIdentifier", None
            ),
            "guardrailVersion": self.guardrails.get(  # type: ignore[union-attr]
                "guardrailVersion", None
            ),
            **_model_kwargs,
        }

    def _get_provider(self) -> str:
        # If provider supplied by user, return as-is
        if self.provider:
            return self.provider

        # If model_id is an arn, can't extract provider from model_id,
        # so this requires passing in the provider by user
        if self.model_id.startswith("arn"):
            raise ValueError(
                "Model provider should be supplied when passing a model ARN as "
                "model_id"
            )

        # If model_id has region prefixed to them,
        # for example eu.anthropic.claude-3-haiku-20240307-v1:0,
        # provider is the second part, otherwise, the first part
        parts = self.model_id.split(".", maxsplit=2)
        return (
            parts[1]
            if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"})
            else parts[0]
        )

    def _get_model(self) -> str:
        return self.model_id.split(".", maxsplit=1)[-1]

    @property
    def _model_is_anthropic(self) -> bool:
        return self._get_provider() == "anthropic"

    @property
    def _guardrails_enabled(self) -> bool:
        """
        Determines if guardrails are enabled and correctly configured.
        Checks if 'guardrails' is a dictionary with non-empty
        'guardrailIdentifier' and 'guardrailVersion' keys.
        Checks if 'guardrails.trace' is true.

        Returns:
            bool: True if guardrails are correctly configured, False otherwise.
        Raises:
            TypeError: If 'guardrails' lacks 'guardrailIdentifier' or
                'guardrailVersion' keys.
        """
        try:
            return (
                isinstance(self.guardrails, dict)
                and bool(self.guardrails["guardrailIdentifier"])
                and bool(self.guardrails["guardrailVersion"])
            )
        except KeyError as e:
            raise TypeError(
                "Guardrails must be a dictionary with 'guardrailIdentifier' "
                "and 'guardrailVersion' keys."
            ) from e

    def _prepare_input_and_invoke(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Tuple[
        str,
        List[ToolCall],
        Dict[str, Any],
    ]:
        _model_kwargs = self.model_kwargs or {}

        provider = self._get_provider()
        params = {**_model_kwargs, **kwargs}

        # Pre-process for thinking with tool use
        if messages and "claude-3" in self._get_model() and thinking_in_params(params):
            # We need to ensure thinking blocks are first in assistant messages
            # Process each message in the sequence
            for i, message in enumerate(messages):
                if message.get("role") == "assistant" and i > 0:
                    content = message.get("content", [])
                    if isinstance(content, list) and content:
                        # Find any thinking blocks
                        thinking_blocks = [
                            j
                            for j, item in enumerate(content)
                            if isinstance(item, dict)
                            and item.get("type") in ["thinking", "redacted_thinking"]
                        ]
                        # If thinking blocks exist but aren't first, reorder
                        if thinking_blocks and thinking_blocks[0] > 0:
                            # Extract thinking blocks
                            thinking_content = [content[j] for j in thinking_blocks]
                            # Extract non-thinking blocks
                            other_content = [
                                item
                                for j, item in enumerate(content)
                                if j not in thinking_blocks
                            ]
                            # Reorder with thinking first
                            message["content"] = thinking_content + other_content

        if "claude-3" in self._get_model() and _tools_in_params(params):
            input_body = LLMInputOutputAdapter.prepare_input(
                provider=provider,
                model_kwargs=params,
                prompt=prompt,
                system=system,
                messages=messages,
                tools=params["tools"],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
        else:
            input_body = LLMInputOutputAdapter.prepare_input(
                provider=provider,
                model_kwargs=params,
                prompt=prompt,
                system=system,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
        body = json.dumps(input_body)
        accept = "application/json"
        contentType = "application/json"

        request_options = {
            "body": body,
            "modelId": self.model_id,
            "accept": accept,
            "contentType": contentType,
        }

        if self._guardrails_enabled:
            request_options["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
                "guardrailIdentifier", ""
            )
            request_options["guardrailVersion"] = self.guardrails.get(  # type: ignore[union-attr]
                "guardrailVersion", ""
            )
            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                request_options["trace"] = "ENABLED"

        try:
            logger.debug(f"Request body sent to bedrock: {request_options}")
            logger.info("Using Bedrock Invoke API to generate response")
            response = self.client.invoke_model(**request_options)

            (
                text,
                thinking,
                tool_calls,
                body,
                usage_info,
                stop_reason,
            ) = LLMInputOutputAdapter.prepare_output(provider, response).values()
            logger.debug(f"Response received from Bedrock: {response}")

        except Exception as e:
            logger.exception("Error raised by bedrock service")
            if run_manager is not None:
                run_manager.on_llm_error(e)
            raise e

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        llm_output = {
            "usage": usage_info,
            "stop_reason": stop_reason,
            "thinking": thinking,
        }

        # Verify and raise a callback error if any intervention occurs or a signal
        # is sent from a Bedrock service, such as when guardrails are triggered.
        services_trace = self._get_bedrock_services_signal(body)  # type: ignore[arg-type]

        if run_manager is not None and services_trace.get("signal"):
            run_manager.on_llm_error(
                Exception(
                    f"Error raised by bedrock service: {services_trace.get('reason')}"
                ),
                **services_trace,
            )

        return text, tool_calls, llm_output

    def _get_bedrock_services_signal(self, body: dict) -> dict:
        """
        This function checks the response body for an interrupt flag or message
        that indicates whether any of the Bedrock services have intervened in the
        processing flow. It is primarily used to identify modifications or
        interruptions imposed by these services during the request-response cycle
        with a Large Language Model (LLM).
        """  # noqa: E501

        if (
            self._guardrails_enabled
            and self.guardrails.get("trace")  # type: ignore[union-attr]
            and self._is_guardrails_intervention(body)
        ):
            return {
                "signal": True,
                "reason": "GUARDRAIL_INTERVENED",
                "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
            }

        return {
            "signal": False,
            "reason": None,
            "trace": None,
        }

    def _is_guardrails_intervention(self, body: dict) -> bool:
        return body.get(GUARDRAILS_BODY_KEY) == "INTERVENED"

    def _prepare_input_and_invoke_stream(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[Union[GenerationChunk, AIMessageChunk]]:
        _model_kwargs = self.model_kwargs or {}
        provider = self._get_provider()

        if stop:
            if provider not in self.provider_stop_sequence_key_name_map:
                raise ValueError(
                    f"Stop sequence key name for {provider} is not supported."
                )

            # stop sequence from _generate() overrides
            # stop sequences in the class attribute
            if k := self.provider_stop_sequence_key_name_map.get(provider):
                _model_kwargs[k] = stop

        if provider == "cohere":
            _model_kwargs["stream"] = True

        params = {**_model_kwargs, **kwargs}

        input_body = LLMInputOutputAdapter.prepare_input(
            provider=provider,
            prompt=prompt,
            system=system,
            messages=messages,
            model_kwargs=params,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )

        coerce_content_to_string = True
        if "claude-3" in self._get_model():
            if _tools_in_params(params):
                coerce_content_to_string = False
                input_body = LLMInputOutputAdapter.prepare_input(
                    provider=provider,
                    model_kwargs=params,
                    prompt=prompt,
                    system=system,
                    messages=messages,
                    tools=params["tools"],
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                )
            elif thinking_in_params(params):
                coerce_content_to_string = False

        body = json.dumps(input_body)

        request_options = {
            "body": body,
            "modelId": self.model_id,
            "accept": "application/json",
            "contentType": "application/json",
        }

        if self._guardrails_enabled:
            request_options["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
                "guardrailIdentifier", ""
            )
            request_options["guardrailVersion"] = self.guardrails.get(  # type: ignore[union-attr]
                "guardrailVersion", ""
            )
            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                request_options["trace"] = "ENABLED"

        try:
            response = self.client.invoke_model_with_response_stream(**request_options)

        except Exception as e:
            logger.exception("Error raised by bedrock service")
            if run_manager is not None:
                run_manager.on_llm_error(e)
            raise e

        for chunk in LLMInputOutputAdapter.prepare_output_stream(
            provider,
            response,
            stop,
            True if messages else False,
            coerce_content_to_string=coerce_content_to_string,
        ):
            yield chunk
            # verify and raise callback error if any middleware intervened
            if not isinstance(chunk, AIMessageChunk):
                self._get_bedrock_services_signal(chunk.generation_info)  # type: ignore[arg-type]

    async def _aprepare_input_and_invoke_stream(
        self,
        prompt: str,
        system: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Union[GenerationChunk, AIMessageChunk]]:
        _model_kwargs = self.model_kwargs or {}
        provider = self._get_provider()

        if stop:
            if provider not in self.provider_stop_sequence_key_name_map:
                raise ValueError(
                    f"Stop sequence key name for {provider} is not supported."
                )
            if k := self.provider_stop_sequence_key_name_map.get(provider):
                _model_kwargs[k] = stop

        if provider == "cohere":
            _model_kwargs["stream"] = True

        params = {**_model_kwargs, **kwargs}
        if "claude-3" in self._get_model() and _tools_in_params(params):
            input_body = LLMInputOutputAdapter.prepare_input(
                provider=provider,
                model_kwargs=params,
                prompt=prompt,
                system=system,
                messages=messages,
                tools=params["tools"],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
        else:
            input_body = LLMInputOutputAdapter.prepare_input(
                provider=provider,
                prompt=prompt,
                system=system,
                messages=messages,
                model_kwargs=params,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
        body = json.dumps(input_body)

        response = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: self.client.invoke_model_with_response_stream(
                body=body,
                modelId=self.model_id,
                accept="application/json",
                contentType="application/json",
            ),
        )

        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
            provider,
            response,
            stop,
            True if messages else False,
        ):
            yield chunk
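# Illustrative sketch (not part of the library): what `BedrockBase` adds to the
# InvokeModel request when guardrails are configured. The identifiers below are
# placeholders; `_guardrails_enabled` only returns True when both
# "guardrailIdentifier" and "guardrailVersion" are non-empty.
#
#     llm = BedrockLLM(
#         model_id="amazon.titan-text-express-v1",
#         guardrails={
#             "guardrailIdentifier": "gr-abc123",  # placeholder id
#             "guardrailVersion": "1",
#             "trace": True,
#         },
#     )
#     # _prepare_input_and_invoke() then calls invoke_model() with
#     # guardrailIdentifier="gr-abc123", guardrailVersion="1", trace="ENABLED",
#     # and _get_bedrock_services_signal() reports GUARDRAIL_INTERVENED when the
#     # response body carries "amazon-bedrock-guardrailAction": "INTERVENED".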
class BedrockLLM(LLM, BedrockBase):
    """Bedrock models.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Bedrock service.
    """

    """
    Example:
        .. code-block:: python

            from langchain_aws import BedrockLLM

            llm = BedrockLLM(
                credentials_profile_name="default",
                model_id="amazon.titan-text-express-v1",
                streaming=True
            )

    """

    @model_validator(mode="after")
    def validate_environment_llm(self) -> Self:
        model_id = self.model_id
        if model_id.startswith("anthropic.claude-3"):
            raise ValueError(
                "Claude v3 models are not supported by this LLM. "
                "Please use `from langchain_aws import ChatBedrock` "
                "instead."
            )
        return self

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "amazon_bedrock"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.region_name:
            attributes["region_name"] = self.region_name

        return attributes

    def _get_ls_params(
        self,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        ls_params = super()._get_ls_params(stop=stop, **kwargs)
        ls_params["ls_provider"] = "amazon_bedrock"
        ls_params["ls_model_name"] = self.model_id
        return ls_params

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
    )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Bedrock service with streaming.

        Args:
            prompt (str): The prompt to pass into the model
            stop (Optional[List[str]], optional): Stop sequences. These will
                override any stop sequences in the `model_kwargs` attribute.
                Defaults to None.
            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
                run managers used to process the output. Defaults to None.

        Returns:
            Iterator[GenerationChunk]: Generator that yields the streamed responses.

        Yields:
            Iterator[GenerationChunk]: Responses from the model.
        """
        return self._prepare_input_and_invoke_stream(  # type: ignore
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Bedrock service model.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm("Tell me a joke.")
        """
        provider = self._get_provider()
        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
            provider, "stop_reason"
        )

        if self.streaming:
            all_chunks: List[GenerationChunk] = []
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
                all_chunks.append(chunk)

            if run_manager is not None:
                chunks_generation_info = [
                    chunk.generation_info
                    for chunk in all_chunks
                    if chunk.generation_info is not None
                ]
                llm_output = _combine_generation_info_for_llm_result(
                    chunks_generation_info,
                    provider_stop_code=provider_stop_reason_code,
                )
                all_generations = [
                    Generation(text=chunk.text, generation_info=chunk.generation_info)
                    for chunk in all_chunks
                ]
                run_manager.on_llm_end(
                    LLMResult(generations=[all_generations], llm_output=llm_output)
                )

            return completion

        text, tool_calls, llm_output = self._prepare_input_and_invoke(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )
        if run_manager is not None:
            run_manager.on_llm_end(
                LLMResult(generations=[[Generation(text=text)]], llm_output=llm_output)
            )

        return text

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[GenerationChunk, None]:
        """Call out to Bedrock service with streaming.

        Args:
            prompt (str): The prompt to pass into the model
            stop (Optional[List[str]], optional): Stop sequences. These will
                override any stop sequences in the `model_kwargs` attribute.
                Defaults to None.
            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
                run managers used to process the output. Defaults to None.

        Yields:
            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously
                yields the streamed responses.
        """
        async for chunk in self._aprepare_input_and_invoke_stream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            yield chunk  # type: ignore

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Bedrock service model.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = await llm._acall("Tell me a joke.")
        """

        if not self.streaming:
            raise ValueError("Streaming must be set to True for async operations. ")

        provider = self._get_provider()
        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
            provider, "stop_reason"
        )

        chunks = [
            chunk
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            )
        ]

        if run_manager is not None:
            chunks_generation_info = [
                chunk.generation_info
                for chunk in chunks
                if chunk.generation_info is not None
            ]
            llm_output = _combine_generation_info_for_llm_result(
                chunks_generation_info, provider_stop_code=provider_stop_reason_code
            )
            generations = [
                Generation(text=chunk.text, generation_info=chunk.generation_info)
                for chunk in chunks
            ]
            await run_manager.on_llm_end(
                LLMResult(generations=[generations], llm_output=llm_output)
            )

        return "".join([chunk.text for chunk in chunks])
    def get_token_ids(self, text: str) -> List[int]:
        if self._model_is_anthropic and not self.custom_get_token_ids:
            if anthropic_tokens_supported():
                return get_token_ids_anthropic(text)
            else:
                warnings.warn(
                    "Falling back to default token method due to missing or "
                    "incompatible `anthropic` installation "
                    "(needs <=0.38.0).\n\nIf using `anthropic>0.38.0`, "
                    "it is recommended to provide the model class with a "
                    "custom_get_token_ids method implementing a more accurate "
                    "tokenizer for Anthropic. For get_num_tokens, as another "
                    "alternative, you can implement your own token counter method "
                    "using the ChatAnthropic or AnthropicLLM classes."
                )
        return super().get_token_ids(text)
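# Illustrative sketch (not part of the library): typical use of `BedrockLLM` for a
# text-completion model. The model id and region are placeholders; credentials are
# resolved by boto3 as described in the class docstring.
#
#     from langchain_aws import BedrockLLM
#
#     llm = BedrockLLM(
#         model_id="amazon.titan-text-express-v1",
#         region_name="us-east-1",
#         model_kwargs={"temperature": 0.2},
#     )
#     print(llm.invoke("Explain Amazon Bedrock in one sentence."))
#
#     # The public .stream() yields text chunks as they arrive (the private
#     # _stream() above yields GenerationChunk objects):
#     for chunk in llm.stream("Write a haiku about clouds."):
#         print(chunk, end="", flush=True)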