class PaiEasChatEndpoint(BaseChatModel):
    """Alibaba Cloud PAI-EAS LLM Service chat model API.

    To use, you must have a deployed EAS chat LLM service on Alibaba Cloud.
    You can set the environment variables ``EAS_SERVICE_URL`` and
    ``EAS_SERVICE_TOKEN`` with your EAS service URL and service token.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import PaiEasChatEndpoint
            eas_chat_endpoint = PaiEasChatEndpoint(
                eas_service_url="your_service_url",
                eas_service_token="your_service_token"
            )
    """

    eas_service_url: str
    """PAI-EAS service URL."""

    eas_service_token: str
    """PAI-EAS service token."""

    # PAI-EAS service inference parameters.
    max_new_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.1
    top_k: Optional[int] = 10
    do_sample: Optional[bool] = False
    use_cache: Optional[bool] = True
    stop_sequences: Optional[List[str]] = None

    streaming: bool = False
    """Enable stream chat mode."""

    model_kwargs: Optional[dict] = None
    """Key/value arguments to pass to the model. Reserved for future use."""

    version: Optional[str] = "2.0"
    timeout: Optional[int] = 5000

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the service URL and token exist in the environment."""
        values["eas_service_url"] = get_from_dict_or_env(
            values, "eas_service_url", "EAS_SERVICE_URL"
        )
        values["eas_service_token"] = get_from_dict_or_env(
            values, "eas_service_token", "EAS_SERVICE_TOKEN"
        )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            "eas_service_url": self.eas_service_url,
            "eas_service_token": self.eas_service_token,
            "model_kwargs": _model_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "pai_eas_chat_endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the PAI-EAS service."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "stop_sequences": [],
            "do_sample": self.do_sample,
            "use_cache": self.use_cache,
        }

    def _invocation_params(
        self, stop_sequences: Optional[List[str]], **kwargs: Any
    ) -> dict:
        params = self._default_params
        if self.model_kwargs:
            params.update(self.model_kwargs)
        if self.stop_sequences is not None and stop_sequences is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop_sequences is not None:
            params["stop"] = self.stop_sequences
        else:
            params["stop"] = stop_sequences
        return {**params, **kwargs}
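    # Illustrative sketch of the merge behavior above (an assumption based on
    # the class defaults, with no instance-level overrides):
    # _invocation_params(["\nHuman:"]) returns
    #   {"max_new_tokens": 512, "temperature": 0.8, "top_k": 10, "top_p": 0.1,
    #    "stop_sequences": [], "do_sample": False, "use_cache": True,
    #    "stop": ["\nHuman:"]},
    # while supplying stop sequences both on the instance and per call raises
    # ValueError.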
    def format_request_payload(
        self, messages: List[BaseMessage], **model_kwargs: Any
    ) -> dict:
        prompt: Dict[str, Any] = {}
        user_content: List[str] = []
        assistant_content: List[str] = []

        for message in messages:
            # Route each message's content according to its role.
            content = cast(str, message.content)
            if isinstance(message, HumanMessage):
                user_content.append(content)
            elif isinstance(message, AIMessage):
                assistant_content.append(content)
            elif isinstance(message, SystemMessage):
                prompt["system_prompt"] = content
            elif isinstance(message, ChatMessage) and message.role in [
                "user",
                "assistant",
                "system",
            ]:
                if message.role == "system":
                    prompt["system_prompt"] = content
                elif message.role == "user":
                    user_content.append(content)
                elif message.role == "assistant":
                    assistant_content.append(content)
            else:
                supported = ",".join(["user", "assistant", "system"])
                raise ValueError(
                    f"Received unsupported role. Supported roles for the "
                    f"LLaMa Foundation Model: {supported}"
                )

        # The last user message is the prompt; earlier user/assistant pairs
        # form the chat history.
        prompt["prompt"] = user_content[-1]
        prompt["history"] = list(zip(user_content[:-1], assistant_content))

        return {**prompt, **model_kwargs}
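    # Illustrative payload shape (a sketch assuming the role mapping above):
    # for messages [SystemMessage("You are helpful."), HumanMessage("Hi"),
    # AIMessage("Hello!"), HumanMessage("Bye")], format_request_payload
    # produces roughly
    #   {
    #       "system_prompt": "You are helpful.",
    #       "prompt": "Bye",
    #       "history": [("Hi", "Hello!")],
    #       # ...plus any merged model kwargs (temperature, top_p, ...)
    #   }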
    def _format_response_payload(
        self, output: bytes, stop_sequences: Optional[List[str]]
    ) -> str:
        """Format the response, enforcing stop sequences if given."""
        try:
            text = json.loads(output)["response"]
            if stop_sequences:
                text = enforce_stop_tokens(text, stop_sequences)
            return text
        except json.decoder.JSONDecodeError:
            # The service returned plain text rather than JSON.
            return output.decode("utf-8")

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        params = self._invocation_params(stop, **kwargs)
        request_payload = self.format_request_payload(messages, **params)
        response_payload = self._call_eas(request_payload)
        generated_text = self._format_response_payload(response_payload, params["stop"])
        if run_manager:
            run_manager.on_llm_new_token(generated_text)
        return generated_text

    def _call_eas(self, query_body: dict) -> Any:
        """Generate text from the EAS service."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"{self.eas_service_token}",
        }

        # make request
        response = requests.post(
            self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
        )
        if response.status_code != 200:
            raise Exception(
                f"Request failed with status code {response.status_code}"
                f" and message {response.text}"
            )
        return response.text

    def _call_eas_stream(self, query_body: dict) -> Any:
        """Generate a streaming response from the EAS service."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"{self.eas_service_token}",
        }

        # make request
        response = requests.post(
            self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
        )
        if response.status_code != 200:
            raise Exception(
                f"Request failed with status code {response.status_code}"
                f" and message {response.text}"
            )
        return response

    def _convert_chunk_to_message_message(
        self,
        chunk: bytes,
    ) -> AIMessageChunk:
        # iter_lines with decode_unicode=False yields bytes, so decode here.
        data = json.loads(chunk.decode("utf-8"))
        return AIMessageChunk(content=data.get("response", ""))

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._invocation_params(stop, **kwargs)
        request_payload = self.format_request_payload(messages, **params)
        request_payload["use_stream_chat"] = True

        response = self._call_eas_stream(request_payload)
        for chunk in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if chunk:
                content = self._convert_chunk_to_message_message(chunk)

                # identify stop sequence in generated text, if any
                stop_seq_found: Optional[str] = None
                for stop_seq in params["stop"] or []:
                    if stop_seq in content.content:
                        stop_seq_found = stop_seq

                # truncate the text at the stop sequence, if one was found
                if stop_seq_found:
                    content.content = content.content[
                        : content.content.index(stop_seq_found)
                    ]

                # yield text, if any
                if content.content:
                    cg_chunk = ChatGenerationChunk(message=content)
                    if run_manager:
                        await run_manager.on_llm_new_token(
                            cast(str, content.content), chunk=cg_chunk
                        )
                    yield cg_chunk

                # break if a stop sequence was found
                if stop_seq_found:
                    break
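
# A minimal usage sketch, assuming a PAI-EAS chat LLM service is already
# deployed and reachable; the URL and token below are placeholders.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    chat = PaiEasChatEndpoint(
        eas_service_url="your_service_url",  # placeholder
        eas_service_token="your_service_token",  # placeholder
    )
    # `invoke` comes from the base Runnable interface of chat models.
    print(chat.invoke([HumanMessage(content="Say hello.")]).content)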