# Source code for langchain_community.chat_models.yi
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type

import requests
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from pydantic import ConfigDict, Field, SecretStr

logger = logging.getLogger(__name__)

# Yi (01.AI) chat-completions endpoints for the China and global regions.
DEFAULT_API_BASE_CN = "https://api.lingyiwanwu.com/v1/chat/completions"
DEFAULT_API_BASE_GLOBAL = "https://api.01.ai/v1/chat/completions"


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message into the role/content dict the Yi API expects.

    Raises:
        TypeError: If ``message`` is not one of the supported message classes.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        # BUG FIX: system messages were previously sent with role "assistant",
        # which silently discarded the system-prompt semantics.
        message_dict = {"role": "system", "content": message.content}
    else:
        raise TypeError(f"Got unknown type {message}")
    return message_dict


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert an API response message dict back into a LangChain message."""
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        # Assistant content may be null/absent in the payload; normalize to "".
        return AIMessage(content=_dict.get("content", "") or "")
    elif role == "system":
        # BUG FIX: "system" roles were previously mapped to AIMessage,
        # losing the system-message type on round trip.
        return SystemMessage(content=_dict["content"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a streaming delta dict into the matching message-chunk class.

    Falls back to ``default_class`` (the class of the previous chunk) when the
    delta does not carry an explicit role.
    """
    # Robustness fix: streamed deltas after the first typically omit "role";
    # the previous `_dict["role"]` raised KeyError for those chunks.
    role: str = _dict.get("role") or ""
    content = _dict.get("content") or ""
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content, type=role)


@asynccontextmanager
async def aconnect_httpx_sse(
    client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator[Any]:
    """Async context manager yielding an SSE ``EventSource`` for a request.

    NOTE(review): ``ChatYi._astream`` calls this helper and the module imports
    ``asynccontextmanager`` for it, but its definition was missing from the
    module body; restored here following the upstream pattern. Requires the
    optional ``httpx-sse`` package at call time (imported lazily).
    """
    from httpx_sse import EventSource

    async with client.stream(method, url, **kwargs) as response:
        yield EventSource(response)
class ChatYi(BaseChatModel):
    """Yi chat models API.

    Talks to the Yi (01.AI) chat-completions endpoint, selecting the China or
    global base URL from ``region`` unless ``yi_api_base`` is given explicitly.
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Maps the constructor field to the environment variable it reads.
        return {
            "yi_api_key": "YI_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    yi_api_base: str = Field(default=DEFAULT_API_BASE_CN)
    yi_api_key: SecretStr = Field(alias="api_key")
    region: str = Field(default="cn")  # defaults to the China-region endpoint
    streaming: bool = False
    request_timeout: int = Field(default=60, alias="timeout")
    model: str = "yi-large"
    temperature: Optional[float] = Field(default=0.7)
    top_p: float = 0.7
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    model_config = ConfigDict(
        populate_by_name=True,
    )

    def __init__(self, **kwargs: Any) -> None:
        """Resolve the API key, pick a regional endpoint, and collect extras.

        Unknown keyword arguments are funneled into ``model_kwargs``.

        Raises:
            ValueError: If a parameter is supplied both directly and via
                ``model_kwargs``, or if a declared field is passed inside
                ``model_kwargs``.
        """
        kwargs["yi_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                kwargs,
                ["yi_api_key", "api_key"],
                "YI_API_KEY",
            )
        )
        if kwargs.get("yi_api_base") is None:
            region = kwargs.get("region", "cn").lower()
            if region == "global":
                kwargs["yi_api_base"] = DEFAULT_API_BASE_GLOBAL
            else:
                kwargs["yi_api_base"] = DEFAULT_API_BASE_CN

        all_required_field_names = get_pydantic_field_names(self.__class__)
        extra = kwargs.get("model_kwargs", {})
        for field_name in list(kwargs):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                extra[field_name] = kwargs.pop(field_name)
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        kwargs["model_kwargs"] = extra
        super().__init__(**kwargs)

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Base request parameters merged into every API call."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": self.streaming,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous completion; delegates to streaming when enabled.

        Raises:
            ValueError: On a non-200 HTTP response from the Yi API.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        res = self._chat(messages, **kwargs)
        if res.status_code != 200:
            raise ValueError(f"Error from Yi api response: {res}")
        response = res.json()
        return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream completion chunks parsed from the SSE response body.

        Raises:
            ValueError: On a non-200 HTTP response from the Yi API.
        """
        res = self._chat(messages, stream=True, **kwargs)
        if res.status_code != 200:
            raise ValueError(f"Error from Yi api response: {res}")
        default_chunk_class = AIMessageChunk
        for chunk in res.iter_lines():
            chunk = chunk.decode("utf-8").strip("\r\n")
            # SSE data lines look like "data: {...}"; skip anything else.
            parts = chunk.split("data: ", 1)
            chunk = parts[1] if len(parts) > 1 else None
            if chunk is None:
                continue
            if chunk == "[DONE]":
                break
            response = json.loads(chunk)
            for m in response.get("choices"):
                chunk = _convert_delta_to_message_chunk(
                    m.get("delta"), default_chunk_class
                )
                # Remember the chunk class so role-less deltas keep their type.
                default_chunk_class = chunk.__class__
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
                yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async completion; delegates to ``_astream`` when streaming.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response (``raise_for_status``).
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        headers = self._create_headers_parameters(**kwargs)
        payload = self._create_payload_parameters(messages, **kwargs)

        import httpx  # optional dependency; imported lazily

        async with httpx.AsyncClient(
            headers=headers, timeout=self.request_timeout
        ) as client:
            response = await client.post(self.yi_api_base, json=payload)
            response.raise_for_status()
        return self._create_chat_result(response.json())

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously stream completion chunks via SSE."""
        headers = self._create_headers_parameters(**kwargs)
        payload = self._create_payload_parameters(messages, stream=True, **kwargs)

        import httpx  # optional dependency; imported lazily

        async with httpx.AsyncClient(
            headers=headers, timeout=self.request_timeout
        ) as client:
            async with aconnect_httpx_sse(
                client, "POST", self.yi_api_base, json=payload
            ) as event_source:
                async for sse in event_source.aiter_sse():
                    chunk = json.loads(sse.data)
                    if len(chunk["choices"]) == 0:
                        continue
                    choice = chunk["choices"][0]
                    chunk = _convert_delta_to_message_chunk(
                        choice["delta"], AIMessageChunk
                    )
                    finish_reason = choice.get("finish_reason", None)
                    generation_info = (
                        {"finish_reason": finish_reason}
                        if finish_reason is not None
                        else None
                    )
                    chunk = ChatGenerationChunk(
                        message=chunk, generation_info=generation_info
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                    yield chunk
                    if finish_reason is not None:
                        break

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        """POST the chat payload and return the raw ``requests`` response."""
        payload = self._create_payload_parameters(messages, **kwargs)
        url = self.yi_api_base
        headers = self._create_headers_parameters(**kwargs)

        res = requests.post(
            url=url,
            timeout=self.request_timeout,
            headers=headers,
            json=payload,
            stream=self.streaming,
        )
        return res

    def _create_payload_parameters(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the JSON request body from defaults plus per-call overrides."""
        parameters = {**self._default_params, **kwargs}
        temperature = parameters.pop("temperature", 0.7)
        top_p = parameters.pop("top_p", 0.7)
        model = parameters.pop("model")
        stream = parameters.pop("stream", False)

        payload = {
            "model": model,
            "messages": [_convert_message_to_dict(m) for m in messages],
            "top_p": top_p,
            "temperature": temperature,
            "stream": stream,
        }
        return payload

    def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
        """Build request headers, honoring a caller-supplied ``headers`` dict."""
        parameters = {**self._default_params, **kwargs}
        default_headers = parameters.pop("headers", {})
        api_key = ""
        if self.yi_api_key:
            api_key = self.yi_api_key.get_secret_value()

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
            **default_headers,
        }
        return headers

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Convert a parsed API response into a ``ChatResult``."""
        generations = []
        for c in response["choices"]:
            message = _convert_dict_to_message(c["message"])
            gen = ChatGeneration(message=message)
            generations.append(gen)

        # Robustness fix: don't KeyError if the provider omits "usage".
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        return "yi-chat"