Source code for langchain_community.chat_models.litellm
"""Wrapper around LiteLLM's model I/O library."""from__future__importannotationsimportjsonimportloggingfromtypingimport(Any,AsyncIterator,Callable,Dict,Iterator,List,Mapping,Optional,Sequence,Tuple,Type,Union,)fromlangchain_core.callbacksimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_modelsimportLanguageModelInputfromlangchain_core.language_models.chat_modelsimport(BaseChatModel,agenerate_from_stream,generate_from_stream,)fromlangchain_core.language_models.llmsimportcreate_base_retry_decoratorfromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,BaseMessageChunk,ChatMessage,ChatMessageChunk,FunctionMessage,FunctionMessageChunk,HumanMessage,HumanMessageChunk,SystemMessage,SystemMessageChunk,ToolCall,ToolCallChunk,ToolMessage,)fromlangchain_core.outputsimport(ChatGeneration,ChatGenerationChunk,ChatResult,)fromlangchain_core.pydantic_v1importBaseModel,Fieldfromlangchain_core.runnablesimportRunnablefromlangchain_core.toolsimportBaseToolfromlangchain_core.utilsimportget_from_dict_or_env,pre_initfromlangchain_core.utils.function_callingimportconvert_to_openai_toollogger=logging.getLogger(__name__)


class ChatLiteLLMException(Exception):
    """Error with the `LiteLLM I/O` library."""


def _create_retry_decorator(
    llm: ChatLiteLLM,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Return a tenacity retry decorator, preconfigured to handle LiteLLM exceptions."""
    import litellm

    errors = [
        litellm.Timeout,
        litellm.APIError,
        litellm.APIConnectionError,
        litellm.RateLimitError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        # Fix for azure
        # Also OpenAI returns None for tool invocations
        content = _dict.get("content", "") or ""
        additional_kwargs = {}
        if _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(_dict["function_call"])
        if _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = _dict["tool_calls"]
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    elif role == "function":
        return FunctionMessage(content=_dict["content"], name=_dict["name"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


async def acompletion_with_retry(
    llm: ChatLiteLLM,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
        return await llm.client.acreate(**kwargs)

    return await _completion_with_retry(**kwargs)


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = _dict.get("role")
    content = _dict.get("content") or ""
    if _dict.get("function_call"):
        additional_kwargs = {"function_call": dict(_dict["function_call"])}
    else:
        additional_kwargs = {}
    tool_call_chunks = []
    if raw_tool_calls := _dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = raw_tool_calls
        try:
            tool_call_chunks = [
                ToolCallChunk(
                    name=rtc["function"].get("name"),
                    args=rtc["function"].get("arguments"),
                    id=rtc.get("id"),
                    index=rtc["index"],
                )
                for rtc in raw_tool_calls
            ]
        except KeyError:
            pass

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    else:
        return default_class(content=content)  # type: ignore[call-arg]


def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
    return {
        "type": "function",
        "id": tool_call["id"],
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        },
    }


def _convert_message_to_dict(message: BaseMessage) -> dict:
    message_dict: Dict[str, Any] = {"content": message.content}
    if isinstance(message, ChatMessage):
        message_dict["role"] = message.role
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
        if message.tool_calls:
            message_dict["tool_calls"] = [
                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
            ]
        elif "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
    elif isinstance(message, SystemMessage):
        message_dict["role"] = "system"
    elif isinstance(message, FunctionMessage):
        message_dict["role"] = "function"
        message_dict["name"] = message.name
    elif isinstance(message, ToolMessage):
        message_dict["role"] = "tool"
        message_dict["tool_call_id"] = message.tool_call_id
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
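

# Illustrative sketch (assumed example values, not taken from this module):
# _convert_message_to_dict maps LangChain messages onto the OpenAI-style dicts
# that LiteLLM expects, e.g.
#
#     _convert_message_to_dict(HumanMessage(content="hi"))
#     # -> {"content": "hi", "role": "user"}
#
#     _convert_message_to_dict(ToolMessage(content="72F", tool_call_id="call_1"))
#     # -> {"content": "72F", "role": "tool", "tool_call_id": "call_1"}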


class ChatLiteLLM(BaseChatModel):
    """Chat model that uses the LiteLLM API."""

    client: Any  #: :meta private:
    model: str = "gpt-3.5-turbo"
    model_name: Optional[str] = None
    """Model name to use."""
    openai_api_key: Optional[str] = None
    azure_api_key: Optional[str] = None
    anthropic_api_key: Optional[str] = None
    replicate_api_key: Optional[str] = None
    cohere_api_key: Optional[str] = None
    openrouter_api_key: Optional[str] = None
    streaming: bool = False
    api_base: Optional[str] = None
    organization: Optional[str] = None
    custom_llm_provider: Optional[str] = None
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    temperature: Optional[float] = 1
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_tokens: Optional[int] = None
    max_retries: int = 6

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the LiteLLM API."""
        set_model_value = self.model
        if self.model_name is not None:
            set_model_value = self.model_name
        return {
            "model": set_model_value,
            "force_timeout": self.request_timeout,
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            "custom_llm_provider": self.custom_llm_provider,
            **self.model_kwargs,
        }

    @property
    def _client_params(self) -> Dict[str, Any]:
        """Get the parameters used for the LiteLLM client."""
        set_model_value = self.model
        if self.model_name is not None:
            set_model_value = self.model_name
        self.client.api_base = self.api_base
        self.client.organization = self.organization
        creds: Dict[str, Any] = {
            "model": set_model_value,
            "force_timeout": self.request_timeout,
            "api_base": self.api_base,
        }
        return {**self._default_params, **creds}
    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.completion(**kwargs)

        return _completion_with_retry(**kwargs)
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k."""
        try:
            import litellm
        except ImportError:
            raise ChatLiteLLMException(
                "Could not import litellm python package. "
                "Please install it with `pip install litellm`"
            )

        values["openai_api_key"] = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY", default=""
        )
        values["azure_api_key"] = get_from_dict_or_env(
            values, "azure_api_key", "AZURE_API_KEY", default=""
        )
        values["anthropic_api_key"] = get_from_dict_or_env(
            values, "anthropic_api_key", "ANTHROPIC_API_KEY", default=""
        )
        values["replicate_api_key"] = get_from_dict_or_env(
            values, "replicate_api_key", "REPLICATE_API_KEY", default=""
        )
        values["openrouter_api_key"] = get_from_dict_or_env(
            values, "openrouter_api_key", "OPENROUTER_API_KEY", default=""
        )
        values["cohere_api_key"] = get_from_dict_or_env(
            values, "cohere_api_key", "COHERE_API_KEY", default=""
        )
        values["huggingface_api_key"] = get_from_dict_or_env(
            values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default=""
        )
        values["together_ai_api_key"] = get_from_dict_or_env(
            values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
        )
        values["client"] = litellm

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        set_model_value = self.model
        if self.model_name is not None:
            set_model_value = self.model_name
        llm_output = {"token_usage": token_usage, "model": set_model_value}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
            yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        LiteLLM expects the tools argument in OpenAI format.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call. Must be the name of
                the single provided function, "auto" to automatically determine which
                function to call (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
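
    # Illustrative sketch (assumed names and model, not part of this module):
    # binding a pydantic tool and requiring it via an OpenAI-style ``tool_choice``
    # could look like:
    #
    #     class GetWeather(BaseModel):
    #         """Get the current weather for a city."""
    #
    #         city: str
    #
    #     llm_with_tools = ChatLiteLLM(model="gpt-3.5-turbo").bind_tools(
    #         [GetWeather],
    #         tool_choice={"type": "function", "function": {"name": "GetWeather"}},
    #     )
    #     ai_msg = llm_with_tools.invoke("What is the weather in Paris?")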
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        set_model_value = self.model
        if self.model_name is not None:
            set_model_value = self.model_name
        return {
            "model": set_model_value,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        return "litellm-chat"
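

# A minimal usage sketch, not part of the original module. It assumes `litellm`
# is installed, that OPENAI_API_KEY is set in the environment, and that the
# model name below is available; adjust both for other providers.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    chat = ChatLiteLLM(model="gpt-3.5-turbo", temperature=0.2)

    # Blocking call: routed through completion_with_retry and assembled into an
    # AIMessage by _create_chat_result.
    print(chat.invoke([HumanMessage(content="Say hello in one word.")]))

    # Streaming call: chunks are produced by _stream as ChatGenerationChunk objects.
    for part in chat.stream("Tell me a one-line joke"):
        print(part.content, end="", flush=True)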