Source code for langchain_community.chat_models.litellm_router
"""LiteLLM Router as LangChain Model."""fromtypingimportAny,AsyncIterator,Iterator,List,Mapping,Optionalfromlangchain_core.callbacks.managerimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_models.chat_modelsimport(agenerate_from_stream,generate_from_stream,)fromlangchain_core.messagesimportAIMessageChunk,BaseMessagefromlangchain_core.outputsimportChatGeneration,ChatGenerationChunk,ChatResultfromlangchain_community.chat_models.litellmimport(ChatLiteLLM,_convert_delta_to_message_chunk,_convert_dict_to_message,)token_usage_key_name="token_usage"# nosec # incorrectly flagged as passwordmodel_extra_key_name="model_extra"# nosec # incorrectly flagged as password
def get_llm_output(usage: Any, **params: Any) -> dict:
    """Get llm output from usage and params."""
    llm_output = {token_usage_key_name: usage}
    # copy over metadata (metadata came from router completion call)
    metadata = params["metadata"]
    for key in metadata:
        if key not in llm_output:
            # keys already set above (e.g. token_usage) are kept; only the
            # remaining metadata keys are copied over
            llm_output[key] = metadata[key]
    return llm_output
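# A minimal sketch of what ``get_llm_output`` produces, assuming a hypothetical
# usage dict and router metadata (the values below are illustrative, not taken
# from a real litellm response):
#
#     >>> get_llm_output({"total_tokens": 42}, metadata={"model_group": "gpt-4o"})
#     {'token_usage': {'total_tokens': 42}, 'model_group': 'gpt-4o'}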
class ChatLiteLLMRouter(ChatLiteLLM):
    """LiteLLM Router as LangChain Model."""

    router: Any

    def __init__(self, *, router: Any, **kwargs: Any) -> None:
        """Construct Chat LiteLLM Router."""
        super().__init__(router=router, **kwargs)  # type: ignore
        self.router = router

    @property
    def _llm_type(self) -> str:
        return "LiteLLMRouter"

    def _prepare_params_for_router(self, params: Any) -> None:
        # allow the router to set api_base based on its model choice
        api_base_key_name = "api_base"
        if api_base_key_name in params and params[api_base_key_name] is None:
            del params[api_base_key_name]

        # add metadata so router can fill it below
        params.setdefault("metadata", {})
    def set_default_model(self, model_name: str) -> None:
        """Set the default model to use for completion calls.

        Sets `self.model` to `model_name` if it is in the litellm router's
        (`self.router`) model list. This provides the default model to use
        for completion calls if no `model` kwarg is provided.
        """
        model_list = self.router.model_list
        if not model_list:
            raise ValueError("model_list is None or empty.")
        for entry in model_list:
            if entry["model_name"] == model_name:
                self.model = model_name
                return
        raise ValueError(f"Model {model_name} not found in model_list.")
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        self._prepare_params_for_router(params)
        response = self.router.completion(
            messages=message_dicts,
            **params,
        )
        return self._create_chat_result(response, **params)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        default_chunk_class = AIMessageChunk
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        self._prepare_params_for_router(params)
        for chunk in self.router.completion(messages=message_dicts, **params):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)
            yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        default_chunk_class = AIMessageChunk
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}
        self._prepare_params_for_router(params)
        async for chunk in await self.router.acompletion(
            messages=message_dicts, **params
        ):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                await run_manager.on_llm_new_token(
                    chunk.content, chunk=cg_chunk, **params
                )
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        self._prepare_params_for_router(params)
        response = await self.router.acompletion(
            messages=message_dicts,
            **params,
        )
        return self._create_chat_result(response, **params)

    # from
    # https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
    # but modified to handle LiteLLM Usage class
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                # get dict from LiteLLM Usage class
                for k, v in token_usage.model_dump().items():
                    if k in overall_token_usage and overall_token_usage[k] is not None:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined

    def _create_chat_result(
        self, response: Mapping[str, Any], **params: Any
    ) -> ChatResult:
        from litellm.utils import Usage

        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", Usage(prompt_tokens=0, total_tokens=0))
        llm_output = get_llm_output(token_usage, **params)
        return ChatResult(generations=generations, llm_output=llm_output)
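A minimal usage sketch (not part of the module source above), assuming a litellm ``Router`` configured with a placeholder ``model_list``; the deployment name, model id, and API key are illustrative only:

from litellm import Router

from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_core.messages import HumanMessage

# Placeholder deployment; swap in a real provider model and credentials.
model_list = [
    {
        "model_name": "gpt-4o-mini",
        "litellm_params": {"model": "openai/gpt-4o-mini", "api_key": "sk-..."},
    },
]
router = Router(model_list=model_list)

chat = ChatLiteLLMRouter(router=router)
chat.set_default_model("gpt-4o-mini")  # default deployment, chosen by model_name
print(chat.invoke([HumanMessage(content="Hello!")]).content)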