class ChatDeepInfraException(Exception):
    """Exception raised when the DeepInfra API returns an error."""

    pass
def _create_retry_decorator(
    llm: ChatDeepInfra,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Return a tenacity retry decorator, preconfigured to handle DeepInfra
    exceptions."""
    return create_base_retry_decorator(
        error_types=[requests.exceptions.ConnectTimeout, ChatDeepInfraException],
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )


def _parse_tool_calling(tool_call: dict) -> ToolCall:
    """Convert a tool call response from the server to a ToolCall object.

    Args:
        tool_call: The raw tool call dict returned by the server.

    Returns:
        The equivalent ToolCall object.
    """
    name = tool_call["function"].get("name", "")
    args = json.loads(tool_call["function"]["arguments"])
    id = tool_call.get("id")
    return create_tool_call(name=name, args=args, id=id)


def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
    """Convert a ToolCall object to a tool calling request for the server.

    Args:
        tool_call: The ToolCall object to convert.

    Returns:
        The tool calling request payload dict.
    """
    return {
        "type": "function",
        "function": {
            "arguments": json.dumps(tool_call["args"]),
            "name": tool_call["name"],
        },
        "id": tool_call.get("id"),
    }


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        content = _dict.get("content", "") or ""
        tool_calls_content = _dict.get("tool_calls", []) or []
        tool_calls = [
            _parse_tool_calling(tool_call) for tool_call in tool_calls_content
        ]
        return AIMessage(content=content, tool_calls=tool_calls)
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    elif role == "function":
        return FunctionMessage(content=_dict["content"], name=_dict["name"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = _dict.get("role")
    content = _dict.get("content") or ""
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        tool_calls = [
            _parse_tool_calling(tool_call)
            for tool_call in _dict.get("tool_calls", [])
        ]
        return AIMessageChunk(content=content, tool_calls=tool_calls)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    else:
        return default_class(content=content)  # type: ignore[call-arg]


def _convert_message_to_dict(message: BaseMessage) -> dict:
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        tool_calls = [
            _convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
        ]
        message_dict = {
            "role": "assistant",
            "content": message.content,
            "tool_calls": tool_calls,  # type: ignore[dict-item]
        }
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "role": "tool",
            "content": message.content,
            "name": message.name,  # type: ignore[dict-item]
            "tool_call_id": message.tool_call_id,
        }
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
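# Illustrative sketch (not part of the original module): the payload shape produced
# by the converters above, assuming the usual langchain_core message imports. An
# AIMessage carrying a tool call serializes to the OpenAI-style "tool_calls" field,
# roughly as follows:
#
#     from langchain_core.messages import AIMessage
#
#     msg = AIMessage(
#         content="",
#         tool_calls=[{"name": "get_weather", "args": {"city": "Berlin"}, "id": "call_1"}],
#     )
#     _convert_message_to_dict(msg)
#     # {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function',
#     #   'function': {'arguments': '{"city": "Berlin"}', 'name': 'get_weather'},
#     #   'id': 'call_1'}]}
#
# _convert_dict_to_message performs the inverse mapping on server responses.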
class ChatDeepInfra(BaseChatModel):
    """A chat model that uses the DeepInfra API."""

    # client: Any  #: :meta private:
    model_name: str = Field(default="meta-llama/Llama-2-70b-chat-hf", alias="model")
    """Model name to use."""
    deepinfra_api_token: Optional[str] = None
    request_timeout: Optional[float] = Field(default=None, alias="timeout")
    temperature: Optional[float] = 1
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_tokens: int = 256
    streaming: bool = False
    max_retries: int = 1

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the DeepInfra
        OpenAI-compatible API."""
        return {
            "model": self.model_name,
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            "request_timeout": self.request_timeout,
            **self.model_kwargs,
        }

    @property
    def _client_params(self) -> Dict[str, Any]:
        """Get the parameters used for the client."""
        return {**self._default_params}
    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            try:
                request_timeout = kwargs.pop("request_timeout")
                request = Requests(headers=self._headers())
                response = request.post(
                    url=self._url(), data=self._body(kwargs), timeout=request_timeout
                )
                self._handle_status(response.status_code, response.text)
                return response
            except Exception as e:
                # Log the failure before letting tenacity decide whether to retry.
                print("EX", e)  # noqa: T201
                raise

        return _completion_with_retry(**kwargs)
    async def acompletion_with_retry(
        self,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use tenacity to retry the async completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        async def _completion_with_retry(**kwargs: Any) -> Any:
            try:
                request_timeout = kwargs.pop("request_timeout")
                request = Requests(headers=self._headers())
                async with request.apost(
                    url=self._url(), data=self._body(kwargs), timeout=request_timeout
                ) as response:
                    self._handle_status(response.status, response.text)
                    return await response.json()
            except Exception as e:
                # Log the failure before letting tenacity decide whether to retry.
                print("EX", e)  # noqa: T201
                raise

        return await _completion_with_retry(**kwargs)
    @root_validator(pre=True)
    def init_defaults(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k."""
        # For compatibility with LiteLLM
        api_key = get_from_dict_or_env(
            values,
            "deepinfra_api_key",
            "DEEPINFRA_API_KEY",
            default="",
        )
        values["deepinfra_api_token"] = get_from_dict_or_env(
            values,
            "deepinfra_api_token",
            "DEEPINFRA_API_TOKEN",
            default=api_key,
        )
        return values

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response.json())

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model": self.model_name}
        res = ChatResult(generations=generations, llm_output=llm_output)
        return res

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        for line in _parse_stream(response.iter_lines()):
            chunk = _handle_sse_line(line)
            if chunk:
                cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
                if run_manager:
                    run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
                yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {"messages": message_dicts, "stream": True, **params, **kwargs}

        request_timeout = params.pop("request_timeout")
        request = Requests(headers=self._headers())
        async with request.apost(
            url=self._url(), data=self._body(params), timeout=request_timeout
        ) as response:
            async for line in _parse_stream_async(response.content):
                chunk = _handle_sse_line(line)
                if chunk:
                    cg_chunk = ChatGenerationChunk(
                        message=chunk, generation_info=None
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(
                            str(chunk.content), chunk=cg_chunk
                        )
                    yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {"messages": message_dicts, **params, **kwargs}
        res = await self.acompletion_with_retry(run_manager=run_manager, **params)
        return self._create_chat_result(res)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        return "deepinfra-chat"

    def _handle_status(self, code: int, text: Any) -> None:
        if code >= 500:
            raise ChatDeepInfraException(
                f"DeepInfra Server error status {code}: {text}"
            )
        elif code >= 400:
            raise ValueError(f"DeepInfra received an invalid payload: {text}")
        elif code != 200:
            raise Exception(
                f"DeepInfra returned an unexpected response with status "
                f"{code}: {text}"
            )

    def _url(self) -> str:
        return "https://stage.api.deepinfra.com/v1/openai/chat/completions"

    def _headers(self) -> Dict:
        return {
            "Authorization": f"bearer {self.deepinfra_api_token}",
            "Content-Type": "application/json",
        }

    def _body(self, kwargs: Any) -> Dict:
        return kwargs
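    # Illustrative note (not part of the original module): the request plumbing above
    # posts the JSON body assembled from _default_params plus the message dicts to
    # the OpenAI-compatible endpoint, authenticated with a bearer token, roughly:
    #
    #     POST https://stage.api.deepinfra.com/v1/openai/chat/completions
    #     Authorization: bearer <DEEPINFRA_API_TOKEN>
    #     Content-Type: application/json
    #
    #     {"model": "...", "messages": [{"role": "user", "content": "..."}],
    #      "max_tokens": 256, "stream": false, "n": 1, "temperature": 1}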
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with the OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
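# Illustrative sketch (not part of the original module): binding a pydantic tool,
# assuming the model chosen on DeepInfra supports OpenAI-style tool calling. The
# class and model names below are only examples.
#
#     from pydantic import BaseModel, Field
#
#     class GetWeather(BaseModel):
#         """Get the current weather for a city."""
#
#         city: str = Field(description="City name")
#
#     llm = ChatDeepInfra(model="meta-llama/Meta-Llama-3-70B-Instruct")
#     llm_with_tools = llm.bind_tools([GetWeather])
#     ai_msg = llm_with_tools.invoke("What is the weather in Berlin?")
#     ai_msg.tool_calls
#     # [{'name': 'GetWeather', 'args': {'city': 'Berlin'}, 'id': '...', ...}]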
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    for line in rbody:
        _line = _parse_stream_helper(line)
        if _line is not None:
            yield _line


async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
    async for line in rbody:
        _line = _parse_stream_helper(line)
        if _line is not None:
            yield _line


def _parse_stream_helper(line: bytes) -> Optional[str]:
    if line and line.startswith(b"data:"):
        if line.startswith(b"data: "):
            # An SSE event may be valid when it contains whitespace.
            line = line[len(b"data: ") :]
        else:
            line = line[len(b"data:") :]
        if line.strip() == b"[DONE]":
            # Returning here will cause a GeneratorExit exception in urllib3,
            # which will close the http connection with a TCP Reset.
            return None
        else:
            return line.decode("utf-8")
    return None


def _handle_sse_line(line: str) -> Optional[BaseMessageChunk]:
    try:
        obj = json.loads(line)
        default_chunk_class = AIMessageChunk
        delta = obj.get("choices", [{}])[0].get("delta", {})
        return _convert_delta_to_message_chunk(delta, default_chunk_class)
    except Exception:
        return None
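# Illustrative usage sketch (not part of the original module). Assumes the
# DEEPINFRA_API_TOKEN environment variable is set and network access is available;
# the model name is only an example.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    chat = ChatDeepInfra(model="meta-llama/Llama-2-70b-chat-hf", temperature=0.7)

    # Plain invocation: messages are serialized by _convert_message_to_dict and sent
    # to the OpenAI-compatible /chat/completions endpoint.
    reply = chat.invoke([HumanMessage(content="Say hello in one short sentence.")])
    print(reply.content)

    # Streaming: internally, _stream parses each SSE "data:" line with
    # _handle_sse_line into an AIMessageChunk before yielding it.
    for chunk in chat.stream("Count from 1 to 5."):
        print(chunk.content, end="", flush=True)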