Source code for langchain_community.chat_models.ernie
import logging
import threading
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import model_validator

logger = logging.getLogger(__name__)


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into the ERNIE ``{"role", "content"}`` shape.

    Role mapping:
        * ``ChatMessage``  -> its own ``role`` attribute
        * ``HumanMessage`` -> ``"user"``
        * ``AIMessage``    -> ``"assistant"``

    Args:
        message: The message to convert.

    Returns:
        A dict with the keys ``"role"`` and ``"content"`` (in that order).

    Raises:
        ValueError: If ``message`` is of any other message type.
    """
    # Resolve the role first; the content is copied verbatim in every case.
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    else:
        raise ValueError(f"Got unknown type {message}")
    return {"role": role, "content": message.content}
@deprecated(
    since="0.0.13",
    alternative="langchain_community.chat_models.QianfanChatEndpoint",
)
class ErnieBotChat(BaseChatModel):
    """`ERNIE-Bot` large language model.

    ERNIE-Bot is a large language model developed by Baidu,
    covering a huge amount of Chinese data.

    To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
    or set the environment variables `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.

    Note:
    The access_token is generated automatically from client_id and
    client_secret on first use. It is regenerated (and the request retried
    once) when the API reports that the token has expired (error code 111).

    Default model is `ERNIE-Bot-turbo`. Currently supported models are
    `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`, `ERNIE-Bot-4`,
    `ERNIE-Bot-turbo-AI`, `BLOOMZ-7B`, `Llama-2-7b-chat`, `Llama-2-13b-chat`
    and `Llama-2-70b-chat`.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ErnieBotChat
            chat = ErnieBotChat(model_name='ERNIE-Bot')


    Deprecated Note:
    Please use `QianfanChatEndpoint` instead of this class.
    `QianfanChatEndpoint` is a more suitable choice for production.

    Always test your code after changing to `QianfanChatEndpoint`.

    Example of `QianfanChatEndpoint`:
        .. code-block:: python

            from langchain_community.chat_models import QianfanChatEndpoint
            qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
                endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")

    """

    ernie_api_base: Optional[str] = None
    """Baidu application custom endpoints"""

    ernie_client_id: Optional[str] = None
    """Baidu application client id"""

    ernie_client_secret: Optional[str] = None
    """Baidu application client secret"""

    access_token: Optional[str] = None
    """access token is generated by client id and client secret,
    setting this value directly will cause an error"""

    model_name: str = "ERNIE-Bot-turbo"
    """model name of ernie, default is `ERNIE-Bot-turbo`.
    Must be one of the names listed in the class docstring; any other value
    raises `ValueError` when a request is made."""

    system: Optional[str] = None
    """system is mainly used for model character design,
    for example, you are an AI assistant produced by xxx company.
    The system prompt is limited to 1024 characters."""

    request_timeout: Optional[int] = 60
    """request timeout for chat http requests"""

    streaming: Optional[bool] = False
    """streaming mode. not supported yet; setting it makes generation raise."""

    top_p: Optional[float] = 0.8
    """Sent to the API as `top_p`."""
    temperature: Optional[float] = 0.95
    """Sent to the API as `temperature`."""
    penalty_score: Optional[float] = 1
    """Sent to the API as `penalty_score`."""

    _lock = threading.Lock()

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Fill the endpoint and credentials from `values` or the environment."""
        values["ernie_api_base"] = get_from_dict_or_env(
            values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
        )
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _chat(self, payload: object) -> dict:
        """POST `payload` to the chat endpoint for `model_name`; return the JSON body.

        Raises:
            ValueError: If `model_name` has no known endpoint path.
        """
        base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
        # Maps the public model name to the endpoint path segment.
        model_paths = {
            "ERNIE-Bot-turbo": "eb-instant",
            "ERNIE-Bot": "completions",
            "ERNIE-Bot-8K": "ernie_bot_8k",
            "ERNIE-Bot-4": "completions_pro",
            "ERNIE-Bot-turbo-AI": "ai_apaas",
            "BLOOMZ-7B": "bloomz_7b1",
            "Llama-2-7b-chat": "llama_2_7b",
            "Llama-2-13b-chat": "llama_2_13b",
            "Llama-2-70b-chat": "llama_2_70b",
        }
        if self.model_name in model_paths:
            url = f"{base_url}/{model_paths[self.model_name]}"
        else:
            raise ValueError(f"Got unknown model_name {self.model_name}")

        resp = requests.post(
            url,
            timeout=self.request_timeout,
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=payload,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        """Request a new access token and store it in `self.access_token`.

        Raises:
            ValueError: If the token endpoint does not return an `access_token`.
                Previously the string "None" was cached in that case, which
                made every later call use an invalid token without refreshing.
        """
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
            resp = requests.post(
                base_url,
                timeout=10,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            body = resp.json()
            token = body.get("access_token")
            if not token:
                raise ValueError(
                    f"Failed to obtain access_token from ErnieChat api: {body}"
                )
            self.access_token = str(token)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send `messages` to the ERNIE chat API and wrap the reply.

        NOTE: `stop` is accepted for interface compatibility but is not
        forwarded to the API. Extra `kwargs` are merged into the payload.

        Raises:
            ValueError: If `streaming` is set, or if the API returns an error.
                Error code 111 (expired token) triggers one token refresh and
                one retry. The retried response is checked for errors as well.
        """
        if self.streaming:
            raise ValueError("`streaming` option currently unsupported.")

        if not self.access_token:
            self._refresh_access_token_with_lock()

        payload = {
            "messages": [_convert_message_to_dict(m) for m in messages],
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
            "system": self.system,
            **kwargs,
        }
        logger.debug("Payload for ernie api is %s", payload)
        resp = self._chat(payload)
        if resp.get("error_code") == 111:
            logger.debug("access_token expired, refresh it")
            self._refresh_access_token_with_lock()
            resp = self._chat(payload)
        # Checked after the optional retry so a failing retry is not returned
        # as an empty successful generation.
        if resp.get("error_code"):
            raise ValueError(f"Error from ErnieChat api response: {resp}")

        return self._create_chat_result(resp)

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Build a single-generation `ChatResult` from the API response body.

        The `result` field becomes the message content. A `function_call`
        field, if present, is copied into `additional_kwargs`. `usage` is
        reported as `token_usage`.
        """
        if "function_call" in response:
            additional_kwargs = {
                "function_call": dict(response.get("function_call", {}))
            }
        else:
            additional_kwargs = {}
        generations = [
            ChatGeneration(
                message=AIMessage(
                    content=response.get("result", ""),
                    additional_kwargs={**additional_kwargs},
                )
            )
        ]
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        return "ernie-bot-chat"