import json
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Type,
    Union,
)

from cohere.types import NonStreamedChatResponse, ToolCall
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolCallChunk,
    ToolMessage,
)
from langchain_core.messages import (
    ToolCall as LC_ToolCall,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool

from langchain_cohere.cohere_agent import (
    _convert_to_cohere_tool,
    _format_to_cohere_tools,
)
from langchain_cohere.llms import BaseCohere
from langchain_cohere.react_multi_hop.prompt import convert_to_documents


def _message_to_cohere_tool_results(
    messages: List[BaseMessage], tool_message_index: int
) -> List[Dict[str, Any]]:
    """Get tool_results from messages."""
    tool_results = []
    tool_message = messages[tool_message_index]
    if not isinstance(tool_message, ToolMessage):
        raise ValueError(
            "The message index does not correspond to an instance of ToolMessage"
        )

    messages_until_tool = messages[:tool_message_index]
    previous_ai_message = [
        message
        for message in messages_until_tool
        if isinstance(message, AIMessage) and message.tool_calls
    ][-1]
    tool_results.extend(
        [
            {
                "call": ToolCall(
                    name=lc_tool_call["name"],
                    parameters=lc_tool_call["args"],
                ),
                "outputs": convert_to_documents(tool_message.content),
            }
            for lc_tool_call in previous_ai_message.tool_calls
            if lc_tool_call["id"] == tool_message.tool_call_id
        ]
    )
    return tool_results


def _get_curr_chat_turn_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
    """Get the messages for the current chat turn."""
    current_chat_turn_messages = []
    for message in messages[::-1]:
        current_chat_turn_messages.append(message)
        if isinstance(message, HumanMessage):
            break
    return current_chat_turn_messages[::-1]


def _messages_to_cohere_tool_results_curr_chat_turn(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Get tool_results from messages."""
    tool_results = []
    curr_chat_turn_messages = _get_curr_chat_turn_messages(messages)
    for message in curr_chat_turn_messages:
        if isinstance(message, ToolMessage):
            tool_message = message
            previous_ai_msgs = [
                message
                for message in curr_chat_turn_messages
                if isinstance(message, AIMessage) and message.tool_calls
            ]
            if previous_ai_msgs:
                previous_ai_msg = previous_ai_msgs[-1]
                tool_results.extend(
                    [
                        {
                            "call": ToolCall(
                                name=lc_tool_call["name"],
                                parameters=lc_tool_call["args"],
                            ),
                            "outputs": convert_to_documents(tool_message.content),
                        }
                        for lc_tool_call in previous_ai_msg.tool_calls
                        if lc_tool_call["id"] == tool_message.tool_call_id
                    ]
                )
    return tool_results


if TYPE_CHECKING:
    from cohere.types import ListModelsResponse  # noqa: F401
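
# A minimal illustrative sketch (not part of the module) of the chat-turn
# helper above: _get_curr_chat_turn_messages walks the history backwards and
# keeps everything from the most recent HumanMessage onwards. The message
# contents below are made up.
#
#     msgs = [
#         HumanMessage(content="What's 3 + 4?"),
#         AIMessage(content="7"),
#         HumanMessage(content="And doubled?"),
#         AIMessage(content="14"),
#     ]
#     _get_curr_chat_turn_messages(msgs)
#     # -> [HumanMessage(content="And doubled?"), AIMessage(content="14")]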
def get_role(message: BaseMessage) -> str:
    """Get the role of the message.

    Args:
        message: The message.

    Returns:
        The role of the message.

    Raises:
        ValueError: If the message is of an unknown type.
    """
    if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
        return "User"
    elif isinstance(message, AIMessage):
        return "Chatbot"
    elif isinstance(message, SystemMessage):
        return "System"
    elif isinstance(message, ToolMessage):
        return "Tool"
    else:
        raise ValueError(f"Got unknown type {type(message).__name__}")
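
# Illustrative mapping produced by get_role (sketch only; contents made up):
#
#     get_role(HumanMessage(content="hi"))           # -> "User"
#     get_role(AIMessage(content="hello"))           # -> "Chatbot"
#     get_role(SystemMessage(content="be concise"))  # -> "System"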
def _get_message_cohere_format(
    message: BaseMessage, tool_results: Optional[List[Dict[Any, Any]]]
) -> Dict[
    str,
    Union[
        str,
        List[LC_ToolCall],
        List[Union[str, Dict[Any, Any]]],
        List[Dict[Any, Any]],
        None,
    ],
]:
    """Get the formatted message as required by Cohere's API.

    Args:
        message: The BaseMessage.
        tool_results: The tool results, if any.

    Returns:
        The formatted message as required by Cohere's API.
    """
    if isinstance(message, AIMessage):
        return {
            "role": get_role(message),
            "message": message.content,
            "tool_calls": message.tool_calls,
        }
    elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
        return {"role": get_role(message), "message": message.content}
    elif isinstance(message, ToolMessage):
        return {"role": get_role(message), "tool_results": tool_results}
    else:
        raise ValueError(f"Got unknown type {message}")
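
# Sketch of the wire format produced above for a human message (contents are
# made up; tool messages carry "tool_results" instead of "message"):
#
#     _get_message_cohere_format(HumanMessage(content="How deep is the Pacific?"), None)
#     # -> {"role": "User", "message": "How deep is the Pacific?"}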
def get_cohere_chat_request(
    messages: List[BaseMessage],
    *,
    documents: Optional[List[Document]] = None,
    connectors: Optional[List[Dict[str, str]]] = None,
    stop_sequences: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Get the request for the Cohere chat API.

    Args:
        messages: The messages.
        connectors: The connectors.
        **kwargs: The keyword arguments.

    Returns:
        The request for the Cohere chat API.
    """
    additional_kwargs = messages[-1].additional_kwargs

    # cohere SDK will fail loudly if both connectors and documents are provided
    if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
        raise ValueError(
            "Received documents both as a keyword argument and as a prompt "
            "additional keyword argument. Please choose only one option."
        )

    parsed_docs: Optional[Union[List[Document], List[Dict]]] = None
    if "documents" in additional_kwargs:
        parsed_docs = (
            additional_kwargs["documents"]
            if len(additional_kwargs.get("documents", []) or []) > 0
            else None
        )
    elif (documents is not None) and (len(documents) > 0):
        parsed_docs = documents

    formatted_docs: Optional[List[Dict[str, Any]]] = None
    if parsed_docs:
        formatted_docs = []
        for i, parsed_doc in enumerate(parsed_docs):
            if isinstance(parsed_doc, Document):
                formatted_docs.append(
                    {
                        "text": parsed_doc.page_content,
                        "id": parsed_doc.metadata.get("id") or f"doc-{str(i)}",
                    }
                )
            elif isinstance(parsed_doc, dict):
                formatted_docs.append(parsed_doc)

    # by enabling automatic prompt truncation, the probability of request failure is
    # reduced with minimal impact on response quality
    prompt_truncation = (
        "AUTO" if formatted_docs is not None or connectors is not None else None
    )
    tool_results: Optional[List[Dict[str, Any]]] = (
        _messages_to_cohere_tool_results_curr_chat_turn(messages)
        or kwargs.get("tool_results")
    )
    if not tool_results:
        tool_results = None

    # check that the last message is a tool message or human message
    if not (
        isinstance(messages[-1], ToolMessage) or isinstance(messages[-1], HumanMessage)
    ):
        raise ValueError("The last message is not a ToolMessage or HumanMessage")

    chat_history = []
    temp_tool_results = []
    # if force_single_step is set to False, then the message is empty in the
    # request when there is a tool call
    if not kwargs.get("force_single_step"):
        for i, message in enumerate(messages[:-1]):
            # If there are multiple tool messages, aggregate them into a single
            # tool message to pass into chat history
            if isinstance(message, ToolMessage):
                temp_tool_results += _message_to_cohere_tool_results(messages, i)

                if (i == len(messages) - 1) or not (
                    isinstance(messages[i + 1], ToolMessage)
                ):
                    cohere_message = _get_message_cohere_format(
                        message, temp_tool_results
                    )
                    chat_history.append(cohere_message)
                    temp_tool_results = []
            else:
                chat_history.append(_get_message_cohere_format(message, None))

        message_str = "" if tool_results else messages[-1].content
    else:
        message_str = ""
        # if force_single_step is set to True, then message is the last human
        # message in the conversation
        for i, message in enumerate(messages[:-1]):
            if isinstance(message, AIMessage) and message.tool_calls:
                continue

            # If there are multiple tool messages, aggregate them into a single
            # tool message to pass into chat history
            if isinstance(message, ToolMessage):
                temp_tool_results += _message_to_cohere_tool_results(messages, i)

                if (i == len(messages) - 1) or not (
                    isinstance(messages[i + 1], ToolMessage)
                ):
                    cohere_message = _get_message_cohere_format(
                        message, temp_tool_results
                    )
                    chat_history.append(cohere_message)
                    temp_tool_results = []
            else:
                chat_history.append(_get_message_cohere_format(message, None))
        # Add the last human message in the conversation to the message string
        for message in messages[::-1]:
            if (isinstance(message, HumanMessage)) and (message.content):
                message_str = message.content
                break

    req = {
        "message": message_str,
        "chat_history": chat_history,
        "tool_results": tool_results,
        "documents": formatted_docs,
        "connectors": connectors,
        "prompt_truncation": prompt_truncation,
        "stop_sequences": stop_sequences,
        **kwargs,
    }

    return {k: v for k, v in req.items() if v is not None}
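
# A small runnable sketch of the request builder above. The helper name and
# message contents are made up for illustration; it is not part of the
# public API.
def _example_chat_request() -> Dict[str, Any]:
    """Illustrative sketch: shows the payload ``get_cohere_chat_request``
    assembles for a simple two-turn conversation. ``None``-valued keys
    (``documents``, ``connectors``, ...) are dropped from the returned dict.
    """
    messages: List[BaseMessage] = [
        HumanMessage(content="Who directed Alien?"),
        AIMessage(content="Ridley Scott directed Alien (1979)."),
        HumanMessage(content="What else did he direct?"),
    ]
    # -> {"message": "What else did he direct?",
    #     "chat_history": [{"role": "User", ...}, {"role": "Chatbot", ...}]}
    return get_cohere_chat_request(messages)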
class ChatCohere(BaseChatModel, BaseCohere):
    """
    Implements the BaseChatModel (and BaseLanguageModel) interface with Cohere's
    large language models.

    Find out more about us at https://cohere.com and
    https://huggingface.co/CohereForAI

    This implementation uses the Chat API - see
    https://docs.cohere.com/reference/chat

    To use this you'll need a Cohere API key - either pass it to the
    cohere_api_key parameter or set the COHERE_API_KEY environment variable.

    API keys are available on https://cohere.com - it's free to sign up and
    trial API keys work with this implementation.

    Basic Example:
        .. code-block:: python

            from langchain_cohere import ChatCohere
            from langchain_core.messages import HumanMessage

            llm = ChatCohere(cohere_api_key="{API KEY}")

            message = [HumanMessage(content="Hello, can you introduce yourself?")]
            print(llm.invoke(message).content)
    """

    preamble: Optional[str] = None

    # Used internally to cache API calls to list models.
    _default_model_name: Optional[str] = PrivateAttr(default=None)

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True
    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The output schema as a dict or a Pydantic class. If a
                Pydantic class then the model output will be an object of that
                class. If a dict then the model output will be a dict.

        Returns:
            A Runnable that takes any ChatModel input and returns either a dict
            or Pydantic class as output.
        """
        is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
        llm = self.bind_tools([schema], **kwargs)
        if is_pydantic_schema:
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[schema], first_tool_only=True
            )
        else:
            key_name = _convert_to_cohere_tool(schema)["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )
        return llm | output_parser
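
    # Illustrative usage sketch (not executed; the schema and prompt are made
    # up, and an API key is assumed via COHERE_API_KEY):
    #
    #     from langchain_core.pydantic_v1 import BaseModel
    #
    #     class Person(BaseModel):
    #         name: str
    #         age: int
    #
    #     structured_llm = ChatCohere().with_structured_output(Person)
    #     structured_llm.invoke("Erick is 27 years old")
    #     # -> Person(name="Erick", age=27)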
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "cohere-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        base_params = {
            "model": self.model,
            "temperature": self.temperature,
            "preamble": self.preamble,
        }
        return {k: v for k, v in base_params.items() if v is not None}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="cohere",
            ls_model_name=self.model_name,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens"):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = get_cohere_chat_request(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )

        if hasattr(self.client, "chat_stream"):  # detect and support sdk v5
            stream = self.client.chat_stream(**request)
        else:
            stream = self.client.chat(**request, stream=True)

        for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            if data.event_type == "tool-calls-chunk":
                if data.tool_call_delta:
                    delta = data.tool_call_delta
                    cohere_tool_call_chunk = _format_cohere_tool_calls([delta])[0]
                    message = AIMessageChunk(
                        content="",
                        tool_call_chunks=[
                            ToolCallChunk(
                                name=cohere_tool_call_chunk["function"].get("name"),
                                args=cohere_tool_call_chunk["function"].get(
                                    "arguments"
                                ),
                                id=cohere_tool_call_chunk.get("id"),
                                index=delta.index,
                            )
                        ],
                    )
                    chunk = ChatGenerationChunk(message=message)
                else:
                    delta = data.text
                    chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            elif data.event_type == "stream-end":
                generation_info = self._get_generation_info(data.response)
                message = AIMessageChunk(
                    content="",
                    additional_kwargs=generation_info,
                )
                yield ChatGenerationChunk(
                    message=message,
                    generation_info=generation_info,
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        request = get_cohere_chat_request(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )

        if hasattr(self.async_client, "chat_stream"):  # detect and support sdk v5
            stream = self.async_client.chat_stream(**request)
        else:
            stream = self.async_client.chat(**request, stream=True)

        async for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            elif data.event_type == "stream-end":
                generation_info = self._get_generation_info(data.response)
                tool_call_chunks = []
                if tool_calls := generation_info.get("tool_calls"):
                    content = data.response.text
                    try:
                        tool_call_chunks = [
                            {
                                "name": tool_call["function"].get("name"),
                                "args": tool_call["function"].get("arguments"),
                                "id": tool_call.get("id"),
                                "index": tool_call.get("index"),
                            }
                            for tool_call in tool_calls
                        ]
                    except KeyError:
                        pass
                else:
                    content = ""
                if isinstance(data.response, NonStreamedChatResponse):
                    usage_metadata = _get_usage_metadata(data.response)
                else:
                    usage_metadata = None
                message = AIMessageChunk(
                    content=content,
                    additional_kwargs=generation_info,
                    tool_call_chunks=tool_call_chunks,
                    usage_metadata=usage_metadata,
                )
                yield ChatGenerationChunk(
                    message=message,
                    generation_info=generation_info,
                )

    def _get_generation_info(
        self, response: NonStreamedChatResponse
    ) -> Dict[str, Any]:
        """Get the generation info from cohere API response."""
        generation_info: Dict[str, Any] = {
            "documents": response.documents,
            "citations": response.citations,
            "search_results": response.search_results,
            "search_queries": response.search_queries,
            "is_search_required": response.is_search_required,
            "generation_id": response.generation_id,
        }
        if response.tool_calls:
            # Only populate tool_calls when 1) present on the response and
            # 2) has one or more calls.
            generation_info["tool_calls"] = _format_cohere_tool_calls(
                response.tool_calls
            )
        if hasattr(response, "token_count"):
            generation_info["token_count"] = response.token_count
        elif hasattr(response, "meta") and response.meta is not None:
            if hasattr(response.meta, "tokens") and response.meta.tokens is not None:
                generation_info["token_count"] = response.meta.tokens.dict()
        return generation_info

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        request = get_cohere_chat_request(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )
        response = self.client.chat(**request)
        generation_info = self._get_generation_info(response)
        if "tool_calls" in generation_info:
            tool_calls = [
                _convert_cohere_tool_call_to_langchain(tool_call)
                for tool_call in response.tool_calls
            ]
        else:
            tool_calls = []
        usage_metadata = _get_usage_metadata(response)
        message = AIMessage(
            content=response.text,
            additional_kwargs=generation_info,
            tool_calls=tool_calls,
            usage_metadata=usage_metadata,
        )
        return ChatResult(
            generations=[
                ChatGeneration(message=message, generation_info=generation_info)
            ]
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        request = get_cohere_chat_request(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )
        response = await self.async_client.chat(**request)
        generation_info = self._get_generation_info(response)
        if "tool_calls" in generation_info:
            tool_calls = [
                _convert_cohere_tool_call_to_langchain(tool_call)
                for tool_call in response.tool_calls
            ]
        else:
            tool_calls = []
        usage_metadata = _get_usage_metadata(response)
        message = AIMessage(
            content=response.text,
            additional_kwargs=generation_info,
            tool_calls=tool_calls,
            usage_metadata=usage_metadata,
        )
        return ChatResult(
            generations=[
                ChatGeneration(message=message, generation_info=generation_info)
            ]
        )

    def _get_default_model(self) -> str:
        """Fetches the current default model name."""
        response = self.client.models.list(
            default_only=True, endpoint="chat"
        )  # type: "ListModelsResponse"
        if not response.models:
            raise Exception("invalid cohere list models response")
        if not response.models[0].name:
            raise Exception("invalid cohere list models response")
        return response.models[0].name

    @property
    def model_name(self) -> str:
        if self.model is not None:
            return self.model
        if self._default_model_name is None:
            self._default_model_name = self._get_default_model()
        return self._default_model_name
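
    # Streaming sketch (illustrative; assumes COHERE_API_KEY is set and the
    # prompt is made up): _stream yields a chunk per "text-generation" event,
    # so .stream() can be consumed incrementally.
    #
    #     llm = ChatCohere()
    #     for chunk in llm.stream("Tell me a fun fact about penguins"):
    #         print(chunk.content, end="", flush=True)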
    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        model = self.model_name
        return len(self.client.tokenize(text=text, model=model).tokens)
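
# Token-counting sketch (illustrative; the tokenize call goes through the
# Cohere API, so a valid API key is assumed and "command-r" is just an
# example model name):
#
#     llm = ChatCohere(model="command-r")
#     llm.get_num_tokens("Hello, world!")  # -> a small positive int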
def _format_cohere_tool_calls(
    tool_calls: Optional[List[ToolCall]] = None,
) -> List[Dict]:
    """
    Formats a Cohere API response into the tool call format used elsewhere in
    LangChain.
    """
    if not tool_calls:
        return []

    formatted_tool_calls = []
    for tool_call in tool_calls:
        formatted_tool_calls.append(
            {
                "id": uuid.uuid4().hex[:],
                "function": {
                    "name": tool_call.name,
                    "arguments": json.dumps(tool_call.parameters),
                },
                "type": "function",
            }
        )
    return formatted_tool_calls


def _convert_cohere_tool_call_to_langchain(tool_call: ToolCall) -> LC_ToolCall:
    """Convert a Cohere tool call into langchain_core.messages.ToolCall."""
    _id = uuid.uuid4().hex[:]
    return LC_ToolCall(name=tool_call.name, args=tool_call.parameters, id=_id)


def _get_usage_metadata(response: NonStreamedChatResponse) -> Optional[UsageMetadata]:
    """Get standard usage metadata from chat response."""
    metadata = response.meta
    if metadata:
        if tokens := metadata.tokens:
            input_tokens = int(tokens.input_tokens or 0)
            output_tokens = int(tokens.output_tokens or 0)
            total_tokens = input_tokens + output_tokens
            return UsageMetadata(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
            )
    return None
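
# A small runnable sketch of the tool-call formatter above. The helper name,
# tool name, and parameters are made up for illustration; it is not part of
# the public API.
def _example_format_tool_call() -> List[Dict]:
    """Illustrative sketch: shows the OpenAI-style dict that
    ``_format_cohere_tool_calls`` produces for one Cohere tool call. The
    ``id`` is a freshly generated UUID hex string.
    """
    cohere_call = ToolCall(
        name="query_daily_sales_report", parameters={"day": "2023-09-29"}
    )
    # -> [{"id": "<32-char hex>", "type": "function",
    #      "function": {"name": "query_daily_sales_report",
    #                   "arguments": '{"day": "2023-09-29"}'}}]
    return _format_cohere_tool_calls([cohere_call])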