import asyncio
import json
import warnings
from abc import ABC
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
)

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import BaseModel, ConfigDict, Field

from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.utilities.anthropic import (
    get_num_tokens_anthropic,
    get_token_ids_anthropic,
)

AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
HUMAN_PROMPT = "\n\nHuman:"
ASSISTANT_PROMPT = "\n\nAssistant:"
ALTERNATION_ERROR = (
    "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
)


def _add_newlines_before_ha(input_text: str) -> str:
    """Put a blank line (two newlines) in front of every Human/Assistant marker.

    Each ``Human:`` / ``Assistant:`` marker is first prefixed with two
    newlines. Two clean-up passes then collapse three newlines in front of the
    marker back to two. As a result, a marker that already had up to two
    leading newlines ends up with exactly two.
    """
    result = input_text
    for marker in ("Human:", "Assistant:"):
        padded = "\n\n" + marker
        result = result.replace(marker, padded)
        for _ in range(2):
            result = result.replace("\n" + padded, padded)
    return result


def _human_assistant_format(input_text: str) -> str:
    """Coerce a prompt into alternating Human/Assistant turns.

    The following corrections are applied silently:

    * A Human turn is prepended when the text has no ``Human:`` marker, or
      when an ``Assistant:`` marker appears before the first ``Human:``.
    * A trailing Assistant turn is appended when there is no ``Assistant:``
      marker, or when the final Human turn is not answered.
    * Blank lines in front of every marker are normalised.

    ``warnings.warn`` is emitted for every place where the turns fail to
    alternate.
    """
    has_human = "Human:" in input_text
    assistant_precedes_human = "Assistant:" in input_text and (
        input_text.find("Human:") > input_text.find("Assistant:")
    )
    if not has_human or assistant_precedes_human:
        input_text = HUMAN_PROMPT + " " + input_text  # SILENT CORRECTION
    if "Assistant:" not in input_text:
        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
    if input_text.startswith("Human:"):
        input_text = "\n\n" + input_text
    input_text = _add_newlines_before_ha(input_text)

    # Scan every position and check that Human and Assistant turns alternate.
    awaiting_assistant = False
    for pos in range(len(input_text)):
        if input_text.startswith(HUMAN_PROMPT, pos):
            if awaiting_assistant:
                warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")
            else:
                awaiting_assistant = True
        if input_text.startswith(ASSISTANT_PROMPT, pos):
            if awaiting_assistant:
                awaiting_assistant = False
            else:
                warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")

    if awaiting_assistant:  # Only saw Human, no Assistant
        input_text = input_text + ASSISTANT_PROMPT  # SILENT CORRECTION
    return input_text


def _stream_response_to_generation_chunk(
    stream_response: Dict[str, Any],
) -> GenerationChunk:
    """Convert a stream response to a generation chunk.

    An empty ``delta`` yields a chunk with empty text and no generation info.
    Otherwise the chunk text is ``delta["text"]``, and the response's
    ``stop_reason`` (``None`` when absent) is recorded as ``finish_reason``.
    """
    delta = stream_response["delta"]
    if delta:
        finish_reason = stream_response.get("stop_reason", None)
        return GenerationChunk(
            text=delta["text"],
            generation_info={"finish_reason": finish_reason},
        )
    return GenerationChunk(text="")
class LLMInputOutputAdapter:
    """Adapter class to prepare the inputs from Langchain to a format
    that LLM model expects.

    It also provides helper function to extract
    the generated text from the model response."""

    provider_to_output_key_map = {
        "anthropic": "completion",
        "amazon": "outputText",
        "cohere": "text",
        "meta": "generation",
        "mistral": "outputs",
    }

    @staticmethod
    def _is_provider_stop_chunk(
        provider: str, chunk_obj: Dict[str, Any], output_key: str
    ) -> bool:
        """Return True when a decoded stream chunk marks the end of the stream.

        A cohere chunk is treated as final when ``is_finished`` is truthy or
        its text equals ``"<EOS_TOKEN>"``. A mistral chunk is treated as final
        when its first output entry has ``stop_reason == "stop"``. Chunks from
        other providers are never treated as final here.
        """
        if provider == "cohere":
            return bool(
                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
            )
        if provider == "mistral":
            first_output = chunk_obj.get(output_key, [{}])[0]
            return first_output.get("stop_reason", "") == "stop"
        return False

    @classmethod
    def prepare_output_stream(
        cls,
        provider: str,
        response: Any,
        stop: Optional[List[str]] = None,
        messages_api: bool = False,
    ) -> Iterator[GenerationChunk]:
        """Yield generation chunks from a synchronous Bedrock response stream.

        Args:
            provider: Model provider name, e.g. ``"cohere"`` or ``"mistral"``.
            response: Streaming response; its ``"body"`` is iterated for events.
            stop: Accepted for interface compatibility; not used here.
            messages_api: Whether events use the messages-API format
                (``message_start`` / ``content_block_delta`` / ...).

        Yields:
            One ``GenerationChunk`` per text-bearing event. Non-messages-API
            chunks carry the guardrail assessment (or ``None``) in
            ``generation_info``.

        Raises:
            ValueError: If no output key is known for ``provider``.
        """
        stream = response.get("body")
        if not stream:
            return

        if messages_api:
            output_key = "message"
        else:
            output_key = cls.provider_to_output_key_map.get(provider, "")
        if not output_key:
            raise ValueError(
                f"Unknown streaming response output key for provider: {provider}"
            )

        for event in stream:
            chunk = event.get("chunk")
            if not chunk:
                continue
            chunk_obj = json.loads(chunk.get("bytes").decode())

            if cls._is_provider_stop_chunk(provider, chunk_obj, output_key):
                return
            chunk_type = chunk_obj.get("type")
            if messages_api and chunk_type == "content_block_stop":
                return

            if messages_api and chunk_type in (
                "message_start",
                "content_block_start",
                "content_block_delta",
            ):
                # Only deltas carry generated text; start events are skipped.
                if chunk_type == "content_block_delta":
                    yield _stream_response_to_generation_chunk(chunk_obj)
                continue

            # chunk obj format varies with provider
            text = (
                chunk_obj[output_key][0]["text"]
                if provider == "mistral"
                else chunk_obj[output_key]
            )
            yield GenerationChunk(
                text=text,
                generation_info={
                    GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
                },
            )

    @classmethod
    async def aprepare_output_stream(
        cls, provider: str, response: Any, stop: Optional[List[str]] = None
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronously yield generation chunks from a Bedrock response stream.

        Args:
            provider: Model provider name, e.g. ``"cohere"`` or ``"mistral"``.
            response: Streaming response; its ``"body"`` is iterated for events.
            stop: Accepted for interface compatibility; not used here.

        Yields:
            One ``GenerationChunk`` per event until the provider's
            end-of-stream marker.

        Raises:
            ValueError: If no output key is known for ``provider``.
        """
        stream = response.get("body")
        if not stream:
            return

        output_key = cls.provider_to_output_key_map.get(provider, None)
        if not output_key:
            raise ValueError(
                f"Unknown streaming response output key for provider: {provider}"
            )

        # Advancing the response body is a blocking read, so each read runs in
        # the default executor to keep the event loop responsive while the model
        # generates. Reads are awaited one at a time, so the iterator is never
        # advanced concurrently. A sentinel default is passed to next() because
        # StopIteration cannot propagate through an asyncio Future.
        loop = asyncio.get_running_loop()
        events = iter(stream)
        exhausted = object()
        while True:
            event = await loop.run_in_executor(None, next, events, exhausted)
            if event is exhausted:
                break
            chunk = event.get("chunk")
            if not chunk:
                continue
            chunk_obj = json.loads(chunk.get("bytes").decode())

            if cls._is_provider_stop_chunk(provider, chunk_obj, output_key):
                return

            yield GenerationChunk(
                text=(
                    chunk_obj[output_key]
                    if provider != "mistral"
                    else chunk_obj[output_key][0]["text"]
                )
            )
class BedrockBase(BaseModel, ABC):
    """Base class for Bedrock models."""

    model_config = ConfigDict(protected_namespaces=())

    client: Any = Field(exclude=True)  #: :meta private:

    region_name: Optional[str] = None
    """The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
    or region specified in ~/.aws/config in case it is not provided here.
    """

    credentials_profile_name: Optional[str] = Field(default=None, exclude=True)
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    config: Any = None
    """An optional botocore.config.Config instance to pass to the client."""

    provider: Optional[str] = None
    """The model provider, e.g., amazon, cohere, ai21, etc. When not supplied, provider
    is extracted from the first part of the model_id e.g. 'amazon' in
    'amazon.titan-text-express-v1'. This value should be provided for model ids that do
    not have the provider in them, e.g., custom and provisioned models that have an ARN
    associated with them."""

    model_id: str
    """Id of the model to call, e.g., amazon.titan-text-express-v1, this is
    equivalent to the modelId property in the list-foundation-models api. For custom and
    provisioned models, an ARN value is expected."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_url: Optional[str] = None
    """Needed if you don't want to default to us-east-1 endpoint"""

    streaming: bool = False
    """Whether to stream the results."""

    provider_stop_sequence_key_name_map: Mapping[str, str] = {
        "anthropic": "stop_sequences",
        "amazon": "stopSequences",
        "ai21": "stop_sequences",
        "cohere": "stop_sequences",
        "mistral": "stop",
    }

    guardrails: Optional[Mapping[str, Any]] = {
        "id": None,
        "version": None,
        "trace": False,
    }
    """
    An optional dictionary to configure guardrails for Bedrock.

    This field 'guardrails' consists of the keys 'id' and 'version', which should be
    strings but are initialized to None, and an optional boolean 'trace'.
    Guardrails are applied only when 'guardrails' is a dict whose 'id' and 'version'
    are both set.

    Type:
        Optional[Mapping[str, Any]]: A mapping with 'id', 'version' and, optionally,
        'trace' keys.

    Example:
    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
                  model_kwargs={},
                  guardrails={
                        "id": "<guardrail_id>",
                        "version": "<guardrail_version>"})

    To enable tracing for guardrails, set the 'trace' key to True and pass a callback handler to the
    'run_manager' parameter of the 'generate', '_call' methods.

    Example:
    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
                  model_kwargs={},
                  guardrails={
                        "id": "<guardrail_id>",
                        "version": "<guardrail_version>",
                        "trace": True},
                callbacks=[BedrockAsyncCallbackHandler()])

    [https://python.langchain.com/docs/modules/callbacks/] for more information on callback handlers.

    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
        async def on_llm_error(
            self,
            error: BaseException,
            **kwargs: Any,
        ) -> Any:
            reason = kwargs.get("reason")
            if reason == "GUARDRAIL_INTERVENED":
                ...Logic to handle guardrail intervention...
    """  # noqa: E501

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that boto3 and AWS credentials are available and build a client.

        A ``client`` passed to the constructor is kept as-is. Otherwise a
        ``bedrock-runtime`` client is created from a boto3 session, using the
        configured profile, region, endpoint and botocore config.

        Raises:
            ImportError: If boto3 is not installed.
            ValueError: If the client cannot be created.
        """
        # Skip creating new client if passed in constructor
        if values.get("client") is not None:
            return values

        try:
            import boto3

            if values["credentials_profile_name"] is not None:
                session = boto3.Session(
                    profile_name=values["credentials_profile_name"]
                )
            else:
                # use default credentials
                session = boto3.Session()

            values["region_name"] = get_from_dict_or_env(
                values,
                "region_name",
                "AWS_DEFAULT_REGION",
                default=session.region_name,
            )

            client_params = {}
            if values["region_name"]:
                client_params["region_name"] = values["region_name"]
            if values["endpoint_url"]:
                client_params["endpoint_url"] = values["endpoint_url"]
            if values["config"]:
                client_params["config"] = values["config"]

            values["client"] = session.client("bedrock-runtime", **client_params)

        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except ValueError as e:
            raise ValueError(f"Error raised by bedrock service: {e}") from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                f"profile name are valid. Bedrock error: {e}"
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_kwargs": self.model_kwargs or {}}

    def _get_provider(self) -> str:
        """Return the explicit provider, or the prefix of ``model_id`` before the first dot.

        Raises:
            ValueError: If ``model_id`` is an ARN and no provider was supplied.
        """
        if self.provider:
            return self.provider
        if self.model_id.startswith("arn"):
            raise ValueError(
                "Model provider should be supplied when passing a model ARN as "
                "model_id"
            )

        return self.model_id.split(".")[0]

    @property
    def _model_is_anthropic(self) -> bool:
        return self._get_provider() == "anthropic"

    @property
    def _guardrails_enabled(self) -> bool:
        """
        Determines if guardrails are enabled and correctly configured.
        Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version' keys.
        The 'trace' key is not considered here.

        Returns:
            bool: True if guardrails are correctly configured, False otherwise.
        Raises:
            TypeError: If 'guardrails' lacks 'id' or 'version' keys.
        """
        try:
            return (
                isinstance(self.guardrails, dict)
                and bool(self.guardrails["id"])
                and bool(self.guardrails["version"])
            )

        except KeyError as e:
            raise TypeError(
                "Guardrails must be a dictionary with 'id' and 'version' keys."
            ) from e

    def _get_guardrails_canonical(self) -> Dict[str, Any]:
        """
        The canonical way to pass in guardrails to the bedrock service
        adheres to the following format:

        "amazon-bedrock-guardrailDetails": {
            "guardrailId": "string",
            "guardrailVersion": "string"
        }
        """
        return {
            "amazon-bedrock-guardrailDetails": {
                "guardrailId": self.guardrails.get("id"),  # type: ignore[union-attr]
                "guardrailVersion": self.guardrails.get("version"),  # type: ignore[union-attr]
            }
        }

    def _get_request_options(self, body: str) -> Dict[str, Any]:
        """Build the keyword arguments for a Bedrock ``invoke_model*`` call.

        When guardrails are enabled, the guardrail switch is added, plus the
        trace switch if ``guardrails["trace"]`` is truthy.
        """
        request_options: Dict[str, Any] = {
            "body": body,
            "modelId": self.model_id,
            "accept": "application/json",
            "contentType": "application/json",
        }
        if self._guardrails_enabled:
            request_options["guardrail"] = "ENABLED"
            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                request_options["trace"] = "ENABLED"
        return request_options

    def _prepare_stream_model_kwargs(
        self, provider: str, stop: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Return a copy of ``model_kwargs`` adjusted for a streaming request.

        A copy is used so that call-time stop sequences and the cohere
        ``stream`` flag never leak into ``self.model_kwargs`` for later calls.

        Raises:
            ValueError: If ``stop`` is given for a provider with no known
                stop-sequence key.
        """
        model_kwargs = dict(self.model_kwargs or {})
        if stop:
            if provider not in self.provider_stop_sequence_key_name_map:
                raise ValueError(
                    f"Stop sequence key name for {provider} is not supported."
                )

            # stop sequence from _generate() overrides
            # stop sequences in the class attribute
            model_kwargs[self.provider_stop_sequence_key_name_map[provider]] = stop

        if provider == "cohere":
            model_kwargs["stream"] = True
        return model_kwargs

    def _report_bedrock_services_signal(
        self,
        body: Optional[dict],
        run_manager: Optional[CallbackManagerForLLMRun],
    ) -> None:
        """Call ``run_manager.on_llm_error`` when a Bedrock service intervened.

        A ``None`` body (e.g. a chunk without generation info) is treated as
        an empty body instead of raising.
        """
        services_trace = self._get_bedrock_services_signal(body or {})
        if services_trace.get("signal") and run_manager is not None:
            run_manager.on_llm_error(
                Exception(
                    f"Error raised by bedrock service: {services_trace.get('reason')}"
                ),
                **services_trace,
            )

    def _prepare_input_and_invoke(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Tuple[str, Dict[str, Any]]:
        """Invoke the model once and return ``(text, usage_info)``.

        Stop sequences are applied to the returned text afterwards. Guardrail
        interventions are reported via ``run_manager.on_llm_error``.

        Raises:
            ValueError: If the Bedrock call or output parsing fails.
        """
        _model_kwargs = self.model_kwargs or {}

        provider = self._get_provider()
        params = {**_model_kwargs, **kwargs}
        if self._guardrails_enabled:
            params.update(self._get_guardrails_canonical())
        input_body = LLMInputOutputAdapter.prepare_input(
            provider=provider,
            model_kwargs=params,
            prompt=prompt,
            system=system,
            messages=messages,
        )
        request_options = self._get_request_options(json.dumps(input_body))

        try:
            response = self.client.invoke_model(**request_options)

            text, body, usage_info = LLMInputOutputAdapter.prepare_output(
                provider, response
            ).values()

        except Exception as e:
            raise ValueError(f"Error raised by bedrock service: {e}") from e

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        # Verify and raise a callback error if any intervention occurs or a signal is
        # sent from a Bedrock service,
        # such as when guardrails are triggered.
        self._report_bedrock_services_signal(body, run_manager)

        return text, usage_info

    def _get_bedrock_services_signal(self, body: dict) -> dict:
        """
        This function checks the response body for an interrupt flag or message that indicates
        whether any of the Bedrock services have intervened in the processing flow. It is
        primarily used to identify modifications or interruptions imposed by these services
        during the request-response cycle with a Large Language Model (LLM).
        """  # noqa: E501

        if (
            self._guardrails_enabled
            and self.guardrails.get("trace")  # type: ignore[union-attr]
            and self._is_guardrails_intervention(body)
        ):
            return {
                "signal": True,
                "reason": "GUARDRAIL_INTERVENED",
                "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
            }

        return {
            "signal": False,
            "reason": None,
            "trace": None,
        }

    def _is_guardrails_intervention(self, body: dict) -> bool:
        return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED"

    def _prepare_input_and_invoke_stream(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Invoke the model with response streaming and yield its chunks.

        Each chunk is checked for guardrail interventions, which are reported
        via ``run_manager.on_llm_error``. Each chunk is also passed to
        ``run_manager.on_llm_new_token``.

        Raises:
            ValueError: If ``stop`` is unsupported for the provider or the
                Bedrock call fails.
        """
        provider = self._get_provider()
        model_kwargs = self._prepare_stream_model_kwargs(provider, stop)
        params = {**model_kwargs, **kwargs}

        if self._guardrails_enabled:
            params.update(self._get_guardrails_canonical())

        input_body = LLMInputOutputAdapter.prepare_input(
            provider=provider,
            prompt=prompt,
            system=system,
            messages=messages,
            model_kwargs=params,
        )
        request_options = self._get_request_options(json.dumps(input_body))

        try:
            response = self.client.invoke_model_with_response_stream(
                **request_options
            )
        except Exception as e:
            raise ValueError(f"Error raised by bedrock service: {e}") from e

        for chunk in LLMInputOutputAdapter.prepare_output_stream(
            provider, response, stop, bool(messages)
        ):
            # verify and raise callback error if any middleware intervened
            self._report_bedrock_services_signal(chunk.generation_info, run_manager)

            if run_manager is not None:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

            yield chunk

    async def _aprepare_input_and_invoke_stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronously invoke the model with response streaming.

        The configured guardrails are applied exactly as in the synchronous
        paths. Each chunk is passed to ``run_manager.on_llm_new_token``.

        Raises:
            ValueError: If ``stop`` is unsupported for the provider.
        """
        provider = self._get_provider()
        model_kwargs = self._prepare_stream_model_kwargs(provider, stop)
        params = {**model_kwargs, **kwargs}

        if self._guardrails_enabled:
            params.update(self._get_guardrails_canonical())

        input_body = LLMInputOutputAdapter.prepare_input(
            provider=provider, prompt=prompt, model_kwargs=params
        )
        request_options = self._get_request_options(json.dumps(input_body))

        # The boto3 client is synchronous; run the request in the default executor.
        response = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: self.client.invoke_model_with_response_stream(**request_options),
        )

        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
            provider, response, stop
        ):
            if run_manager is not None and asyncio.iscoroutinefunction(
                run_manager.on_llm_new_token
            ):
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            elif run_manager is not None:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)  # type: ignore[unused-coroutine]
            yield chunk
@deprecated(
    since="0.0.34", removal="1.0", alternative_import="langchain_aws.BedrockLLM"
)
class Bedrock(LLM, BedrockBase):
    """Bedrock models.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Bedrock service.

    Example:
        .. code-block:: python

            from langchain_community.llms.bedrock import Bedrock

            llm = Bedrock(
                credentials_profile_name="default",
                model_id="amazon.titan-text-express-v1",
                streaming=True,
            )
    """

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Reject Claude v3 model ids, then run the base environment validation.

        Raises:
            ValueError: If ``model_id`` starts with ``anthropic.claude-3``.
        """
        model_id = values["model_id"]
        if model_id.startswith("anthropic.claude-3"):
            raise ValueError(
                "Claude v3 models are not supported by this LLM. "
                "Please use `from langchain_community.chat_models import BedrockChat` "
                "instead."
            )
        return super().validate_environment(values)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "amazon_bedrock"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.region_name:
            attributes["region_name"] = self.region_name

        return attributes

    model_config = ConfigDict(
        extra="forbid",
    )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Bedrock service with streaming.

        Args:
            prompt (str): The prompt to pass into the model
            stop (Optional[List[str]], optional): Stop sequences. These will
                override any stop sequences in the `model_kwargs` attribute.
                Defaults to None.
            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
                run managers used to process the output. Defaults to None.

        Returns:
            Iterator[GenerationChunk]: Generator that yields the streamed responses.

        Yields:
            Iterator[GenerationChunk]: Responses from the model.
        """
        return self._prepare_input_and_invoke_stream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Bedrock service model.

        When ``streaming`` is enabled, the streamed chunks are concatenated
        into the result.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """

        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(
                    prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
                )
            )

        text, _ = self._prepare_input_and_invoke(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )
        return text

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[GenerationChunk, None]:
        """Call out to Bedrock service with streaming.

        Args:
            prompt (str): The prompt to pass into the model
            stop (Optional[List[str]], optional): Stop sequences. These will
                override any stop sequences in the `model_kwargs` attribute.
                Defaults to None.
            run_manager (Optional[AsyncCallbackManagerForLLMRun], optional):
                Callback run managers used to process the output. Defaults to None.

        Yields:
            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously yields
            the streamed responses.
        """
        async for chunk in self._aprepare_input_and_invoke_stream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            yield chunk

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Bedrock service model.

        Async calls are only supported in streaming mode.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If ``streaming`` is False.

        Example:
            .. code-block:: python

                response = await llm._acall("Tell me a joke.")
        """

        if not self.streaming:
            raise ValueError("Streaming must be set to True for async operations. ")

        chunks = [
            chunk.text
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            )
        ]
        return "".join(chunks)