import copy
import json
import uuid
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    MutableMapping,
    Optional,
    Sequence,
    Type,
    Union,
)

from cohere.types import (
    AssistantChatMessageV2,
    ChatMessageV2,
    ChatResponse,
    DocumentToolContent,
    NonStreamedChatResponse,
    SystemChatMessageV2,
    ToolCall,
    ToolCallV2,
    ToolCallV2Function,
    ToolChatMessageV2,
    UserChatMessageV2,
)
from cohere.types import Document as DocumentV2
from langchain_core._api.deprecation import warn_deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolCallChunk,
    ToolMessage,
)
from langchain_core.messages import (
    ToolCall as LC_ToolCall,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict, PrivateAttr

from langchain_cohere.cohere_agent import (
    _convert_to_cohere_tool,
    _format_to_cohere_tools_v2,
)
from langchain_cohere.llms import BaseCohere
from langchain_cohere.react_multi_hop.prompt import convert_to_documents

LC_TOOL_CALL_TEMPLATE = {
    "id": "",
    "type": "function",
    "function": {
        "name": "",
        "arguments": "",
    },
}


def _message_to_cohere_tool_results(
    messages: List[BaseMessage], tool_message_index: int
) -> List[Dict[str, Any]]:
    """Get tool_results from messages."""
    tool_results = []
    tool_message = messages[tool_message_index]
    if not isinstance(tool_message, ToolMessage):
        raise ValueError(
            "The message index does not correspond to an instance of ToolMessage"
        )

    messages_until_tool = messages[:tool_message_index]
    previous_ai_message = [
        message
        for message in messages_until_tool
        if isinstance(message, AIMessage) and message.tool_calls
    ][-1]
    tool_results.extend(
        [
            {
                "call": ToolCall(
                    name=lc_tool_call["name"],
                    parameters=lc_tool_call["args"],
                ),
                "outputs": convert_to_documents(tool_message.content),
            }
            for lc_tool_call in previous_ai_message.tool_calls
            if lc_tool_call["id"] == tool_message.tool_call_id
        ]
    )
    return tool_results


def _get_curr_chat_turn_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
    """Get the messages for the current chat turn."""
    current_chat_turn_messages = []
    for message in messages[::-1]:
        current_chat_turn_messages.append(message)
        if isinstance(message, HumanMessage):
            break
    return current_chat_turn_messages[::-1]


def _messages_to_cohere_tool_results_curr_chat_turn(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Get tool_results from messages."""
    tool_results = []
    curr_chat_turn_messages = _get_curr_chat_turn_messages(messages)
    for message in curr_chat_turn_messages:
        if isinstance(message, ToolMessage):
            tool_message = message
            previous_ai_msgs = [
                message
                for message in curr_chat_turn_messages
                if isinstance(message, AIMessage) and message.tool_calls
            ]
            if previous_ai_msgs:
                previous_ai_msg = previous_ai_msgs[-1]
                tool_results.extend(
                    [
                        {
                            "call": ToolCall(
                                name=lc_tool_call["name"],
                                parameters=lc_tool_call["args"],
                            ),
                            "outputs": convert_to_documents(tool_message.content),
                        }
                        for lc_tool_call in previous_ai_msg.tool_calls
                        if lc_tool_call["id"] == tool_message.tool_call_id
                    ]
                )
    return tool_results
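
# Illustrative sketch (not part of the module API): `_get_curr_chat_turn_messages`
# walks the history backwards and keeps everything from the most recent
# HumanMessage onwards, so tool results are only collected for the current turn.
def _example_current_chat_turn() -> None:
    history = [
        HumanMessage(content="What's the weather in Toronto?"),
        AIMessage(content="It is sunny."),
        HumanMessage(content="And in Montreal?"),
    ]
    # Only the trailing HumanMessage belongs to the current turn.
    assert _get_curr_chat_turn_messages(history) == [history[-1]]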
def get_role(message: BaseMessage) -> str:
    """Get the role of the message.

    Args:
        message: The message.

    Returns:
        The role of the message.

    Raises:
        ValueError: If the message is of an unknown type.
    """
    if isinstance(message, (ChatMessage, HumanMessage)):
        return "User"
    elif isinstance(message, AIMessage):
        return "Chatbot"
    elif isinstance(message, SystemMessage):
        return "System"
    elif isinstance(message, ToolMessage):
        return "Tool"
    else:
        raise ValueError(f"Got unknown type {type(message).__name__}")
def _get_message_cohere_format(
    message: BaseMessage, tool_results: Optional[List[Dict[Any, Any]]]
) -> Dict[
    str,
    Union[
        str,
        List[LC_ToolCall],
        List[ToolCall],
        List[Union[str, Dict[Any, Any]]],
        List[Dict[Any, Any]],
        None,
    ],
]:
    """Get the formatted message as required in cohere's api.

    Args:
        message: The BaseMessage.
        tool_results: The tool results, if any.

    Returns:
        The formatted message as required in cohere's api.
    """
    if isinstance(message, AIMessage):
        return {
            "role": get_role(message),
            "message": message.content,
            "tool_calls": _get_tool_call_cohere_format(message.tool_calls),
        }
    elif isinstance(message, (HumanMessage, SystemMessage)):
        return {"role": get_role(message), "message": message.content}
    elif isinstance(message, ToolMessage):
        return {"role": get_role(message), "tool_results": tool_results}
    else:
        raise ValueError(f"Got unknown type {message}")


def _get_tool_call_cohere_format(tool_calls: List[LC_ToolCall]) -> List[ToolCall]:
    """Convert LangChain tool calls into Cohere's format."""
    cohere_tool_calls = []
    for lc_tool_call in tool_calls:
        name = lc_tool_call.get("name")
        parameters = lc_tool_call.get("args")
        id = lc_tool_call.get("id")
        cohere_tool_calls.append(ToolCall(name=name, parameters=parameters, id=id))
    return cohere_tool_calls
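
# Illustrative sketch (hypothetical values): a LangChain tool call is mapped onto
# Cohere's v1 `ToolCall`, renaming `args` to `parameters`.
def _example_tool_call_cohere_format() -> None:
    lc_call = LC_ToolCall(name="get_weather", args={"city": "Toronto"}, id="call-1")
    cohere_call = _get_tool_call_cohere_format([lc_call])[0]
    assert cohere_call.name == "get_weather"
    assert cohere_call.parameters == {"city": "Toronto"}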
def get_cohere_chat_request(
    messages: List[BaseMessage],
    *,
    documents: Optional[List[Document]] = None,
    connectors: Optional[List[Dict[str, str]]] = None,
    stop_sequences: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Get the request for the Cohere chat API.

    Args:
        messages: The messages.
        documents: Documents to ground the response on, if any.
        connectors: The connectors (deprecated).
        stop_sequences: Sequences at which to stop generation.
        **kwargs: Additional keyword arguments passed through to the API.

    Returns:
        The request for the Cohere chat API.
    """
    if connectors or "connectors" in kwargs:
        warn_deprecated(
            since="0.3.3",
            message=(
                "The 'connectors' parameter is deprecated as of version 0.3.3.\n"
                "Please use the 'tools' parameter instead."
            ),
            removal="0.4.0",
        )

    additional_kwargs = messages[-1].additional_kwargs

    # The cohere SDK will fail loudly if both connectors and documents are provided.
    if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
        raise ValueError(
            "Received documents both as a keyword argument and as a prompt "
            "additional keyword argument. Please choose only one option."
        )

    parsed_docs: Optional[Union[List[Document], List[Dict]]] = None
    if "documents" in additional_kwargs:
        parsed_docs = (
            additional_kwargs["documents"]
            if len(additional_kwargs.get("documents", []) or []) > 0
            else None
        )
    elif (documents is not None) and (len(documents) > 0):
        parsed_docs = documents

    formatted_docs: Optional[List[Dict[str, Any]]] = None
    if parsed_docs:
        formatted_docs = []
        for i, parsed_doc in enumerate(parsed_docs):
            if isinstance(parsed_doc, Document):
                formatted_docs.append(
                    {
                        "text": parsed_doc.page_content,
                        "id": parsed_doc.metadata.get("id") or f"doc-{str(i)}",
                    }
                )
            elif isinstance(parsed_doc, dict):
                formatted_docs.append(parsed_doc)

    # By enabling automatic prompt truncation, the probability of request failure
    # is reduced with minimal impact on response quality.
    prompt_truncation = (
        "AUTO" if formatted_docs is not None or connectors is not None else None
    )
    tool_results: Optional[List[Dict[str, Any]]] = (
        _messages_to_cohere_tool_results_curr_chat_turn(messages)
        or kwargs.get("tool_results")
    )
    if not tool_results:
        tool_results = None

    # Check that the last message is either a ToolMessage or a HumanMessage.
    if not isinstance(messages[-1], (ToolMessage, HumanMessage)):
        raise ValueError("The last message is not a ToolMessage or HumanMessage")

    chat_history = []
    temp_tool_results = []
    # If force_single_step is False, then the request's message is empty
    # whenever there is a tool call.
    if not kwargs.get("force_single_step"):
        for i, message in enumerate(messages[:-1]):
            # If there are multiple tool messages, aggregate them into one single
            # tool message to pass into chat history.
            if isinstance(message, ToolMessage):
                temp_tool_results += _message_to_cohere_tool_results(messages, i)

                if (i == len(messages) - 1) or not isinstance(
                    messages[i + 1], ToolMessage
                ):
                    cohere_message = _get_message_cohere_format(
                        message, temp_tool_results
                    )
                    chat_history.append(cohere_message)
                    temp_tool_results = []
            else:
                chat_history.append(_get_message_cohere_format(message, None))

        message_str = "" if tool_results else messages[-1].content
    else:
        message_str = ""
        # If force_single_step is True, then the message is the last human
        # message in the conversation.
        for i, message in enumerate(messages[:-1]):
            if isinstance(message, AIMessage) and message.tool_calls:
                continue

            # If there are multiple tool messages, aggregate them into one single
            # tool message to pass into chat history.
            if isinstance(message, ToolMessage):
                temp_tool_results += _message_to_cohere_tool_results(messages, i)

                if (i == len(messages) - 1) or not isinstance(
                    messages[i + 1], ToolMessage
                ):
                    cohere_message = _get_message_cohere_format(
                        message, temp_tool_results
                    )
                    chat_history.append(cohere_message)
                    temp_tool_results = []
            else:
                chat_history.append(_get_message_cohere_format(message, None))

        # Add the last human message in the conversation to the message string.
        for message in messages[::-1]:
            if isinstance(message, HumanMessage) and message.content:
                message_str = message.content
                break

    req = {
        "message": message_str,
        "chat_history": chat_history,
        "tool_results": tool_results,
        "documents": formatted_docs,
        "connectors": connectors,
        "prompt_truncation": prompt_truncation,
        "stop_sequences": stop_sequences,
        **kwargs,
    }

    return {k: v for k, v in req.items() if v is not None}
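
# Illustrative sketch (hypothetical content): the v1 request builder splits the
# history into `message` (the latest human turn) and `chat_history`, formats
# documents, and drops None-valued keys from the final payload.
def _example_chat_request_v1() -> None:
    request = get_cohere_chat_request(
        [HumanMessage(content="Hello!")],
        documents=[Document(page_content="Cohere is an AI company.")],
    )
    assert request["message"] == "Hello!"
    assert request["documents"][0]["text"] == "Cohere is an AI company."
    # Prompt truncation is enabled automatically when documents are supplied.
    assert request["prompt_truncation"] == "AUTO"
    assert "connectors" not in request  # None values are filtered out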
def get_role_v2(message: BaseMessage) -> str:
    """Get the role of the message (V2).

    Args:
        message: The message.

    Returns:
        The role of the message.

    Raises:
        ValueError: If the message is of an unknown type.
    """
    if isinstance(message, (ChatMessage, HumanMessage)):
        return "user"
    elif isinstance(message, AIMessage):
        return "assistant"
    elif isinstance(message, SystemMessage):
        return "system"
    elif isinstance(message, ToolMessage):
        return "tool"
    else:
        raise ValueError(f"Got unknown type {type(message).__name__}")
def _get_message_cohere_format_v2(
    message: BaseMessage, tool_results: Optional[List[MutableMapping]] = None
) -> ChatMessageV2:
    """Get the formatted message as required in cohere's api (V2).

    Args:
        message: The BaseMessage.
        tool_results: The tool results, if any.

    Returns:
        The formatted message as required in cohere's api.
    """
    if isinstance(message, AIMessage):
        if message.tool_calls:
            return AssistantChatMessageV2(
                role=get_role_v2(message),
                tool_plan=message.content
                if message.content
                else "I will assist you using the tools provided.",
                tool_calls=[
                    ToolCallV2(
                        id=tool_call.get("id"),
                        type="function",
                        function=ToolCallV2Function(
                            name=tool_call.get("name"),
                            arguments=json.dumps(tool_call.get("args")),
                        ),
                    )
                    for tool_call in message.tool_calls
                ],
            )
        return AssistantChatMessageV2(
            role=get_role_v2(message),
            content=message.content,
        )
    elif isinstance(message, HumanMessage):
        return UserChatMessageV2(
            role=get_role_v2(message),
            content=message.content,
        )
    elif isinstance(message, SystemMessage):
        return SystemChatMessageV2(
            role=get_role_v2(message),
            content=message.content,
        )
    elif isinstance(message, ToolMessage):
        if tool_results is None:
            raise ValueError("Tool results are required for ToolMessage")
        content = [
            DocumentToolContent(
                type="document",
                document=DocumentV2(
                    data=dict(tool_result),
                ),
            )
            for tool_result in tool_results
        ]
        if not content:
            content = [
                DocumentToolContent(
                    type="document", document=DocumentV2(data={"output": ""})
                )
            ]
        return ToolChatMessageV2(
            role=get_role_v2(message),
            tool_call_id=message.tool_call_id,
            content=content,
        )
    else:
        raise ValueError(f"Got unknown type {message}")
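
# Illustrative sketch: a plain HumanMessage maps onto the SDK's UserChatMessageV2;
# ToolMessages additionally require `tool_results` to build document content.
def _example_message_format_v2() -> None:
    formatted = _get_message_cohere_format_v2(HumanMessage(content="Hi"))
    assert isinstance(formatted, UserChatMessageV2)
    assert formatted.content == "Hi"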
def get_cohere_chat_request_v2(
    messages: List[BaseMessage],
    *,
    documents: Optional[List[Document]] = None,
    stop_sequences: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Get the request for the Cohere chat API (V2).

    Args:
        messages: The messages.
        documents: Documents to ground the response on, if any.
        stop_sequences: Sequences at which to stop generation.
        **kwargs: Additional keyword arguments passed through to the API.

    Returns:
        The request for the Cohere chat API.
    """
    additional_kwargs = messages[-1].additional_kwargs

    # The cohere SDK will fail loudly if both connectors and documents are provided.
    if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
        raise ValueError(
            "Received documents both as a keyword argument and as a prompt "
            "additional keyword argument. Please choose only one option."
        )

    parsed_docs: Optional[Union[List[Document], List[Dict]]] = None
    if "documents" in additional_kwargs:
        parsed_docs = (
            additional_kwargs["documents"]
            if len(additional_kwargs.get("documents", []) or []) > 0
            else None
        )
    elif (documents is not None) and (len(documents) > 0):
        parsed_docs = documents

    formatted_docs: Optional[List[Union[str, Dict[str, Any]]]] = None
    if parsed_docs:
        formatted_docs = []
        for i, parsed_doc in enumerate(parsed_docs):
            if isinstance(parsed_doc, Document):
                formatted_docs.append(
                    {
                        "id": parsed_doc.metadata.get("id") or f"doc-{str(i)}",
                        "data": {
                            "text": parsed_doc.page_content,
                        },
                    }
                )
            elif isinstance(parsed_doc, dict):
                if "data" not in parsed_doc:
                    formatted_docs.append(
                        {
                            "id": parsed_doc.get("id") or f"doc-{str(i)}",
                            "data": {
                                **parsed_doc,
                            },
                        }
                    )
                else:
                    formatted_docs.append(parsed_doc)
            elif isinstance(parsed_doc, str):
                formatted_docs.append(parsed_doc)
            else:
                formatted_docs.append({"data": parsed_doc})

    # Check that the last message is either a ToolMessage or a HumanMessage.
    if not isinstance(messages[-1], (ToolMessage, HumanMessage)):
        raise ValueError("The last message is not a ToolMessage or HumanMessage")

    if kwargs.get("preamble"):
        messages = [SystemMessage(content=str(kwargs.get("preamble")))] + messages
        del kwargs["preamble"]

    if kwargs.get("connectors"):
        warn_deprecated(
            "0.4.0",
            message=(
                "The 'connectors' parameter is deprecated as of version 0.4.0.\n"
                "Please use the 'tools' parameter instead."
            ),
            removal="0.4.0",
        )
        raise ValueError(
            "The 'connectors' parameter is deprecated as of version 0.4.0."
        )

    chat_history_with_curr_msg = []
    for message in messages:
        if isinstance(message, ToolMessage):
            tool_output = convert_to_documents(message.content)
            chat_history_with_curr_msg.append(
                _get_message_cohere_format_v2(message, tool_output)
            )
        else:
            chat_history_with_curr_msg.append(
                _get_message_cohere_format_v2(message, None)
            )

    req = {
        "messages": chat_history_with_curr_msg,
        "documents": formatted_docs,
        "stop_sequences": stop_sequences,
        **kwargs,
    }

    return {k: v for k, v in req.items() if v is not None}
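
# Illustrative sketch (hypothetical content): in the v2 request, documents are
# wrapped in a "data" field and a `preamble` kwarg becomes a leading system message.
def _example_chat_request_v2() -> None:
    request = get_cohere_chat_request_v2(
        [HumanMessage(content="Hello!")],
        documents=[Document(page_content="Cohere is an AI company.")],
        preamble="You are a concise assistant.",
    )
    assert request["documents"][0]["data"]["text"] == "Cohere is an AI company."
    assert isinstance(request["messages"][0], SystemChatMessageV2)
    assert isinstance(request["messages"][1], UserChatMessageV2)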
class ChatCohere(BaseChatModel, BaseCohere):
    """
    Implements the BaseChatModel (and BaseLanguageModel) interface with Cohere's
    large language models.

    Find out more about us at https://cohere.com and
    https://huggingface.co/CohereForAI

    This implementation uses the Chat API - see
    https://docs.cohere.com/reference/chat

    To use this you'll need a Cohere API key - either pass it to the
    cohere_api_key parameter or set the COHERE_API_KEY environment variable.

    API keys are available on https://cohere.com - it's free to sign up and
    trial API keys work with this implementation.

    Basic Example:
        .. code-block:: python

            from langchain_cohere import ChatCohere
            from langchain_core.messages import HumanMessage

            llm = ChatCohere(cohere_api_key="{API KEY}")

            message = [HumanMessage(content="Hello, can you introduce yourself?")]

            print(llm.invoke(message).content)
    """

    preamble: Optional[str] = None

    _default_model_name: Optional[str] = PrivateAttr(
        default=None
    )  # Used internally to cache API calls to list models.

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )
    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        method: Literal[
            "function_calling", "tool_calling", "json_mode", "json_schema"
        ] = "json_schema",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Given schema can be a Pydantic class or a dict.

        Args:
            schema:
                The output schema as a dict or a Pydantic class. If a Pydantic
                class, then the model output will be an object of that class.
                If a dict, then the model output will be a dict.
            method: The method for steering model generation, one of:

                - "function_calling" or "tool_calling":
                    Uses Cohere's tool-calling (formerly called function calling)
                    API: https://docs.cohere.com/v2/docs/tool-use
                - "json_schema":
                    Uses Cohere's Structured Output API:
                    https://docs.cohere.com/docs/structured-outputs
                    Allows the user to pass a json schema (or pydantic class) to
                    the model for structured output. This is the default method.
                    Supported for "command-r", "command-r-plus", and later models.
                - "json_mode":
                    Uses Cohere's Structured Output API:
                    https://docs.cohere.com/docs/structured-outputs
                    Supported for "command-r", "command-r-plus", and later models.

        Returns:
            A Runnable that takes any ChatModel input and returns either a dict
            or Pydantic class as output.
        """
        if (not schema) and (method != "json_mode"):
            raise ValueError(
                "schema must be specified when method is not 'json_mode'. "
                f"Received {schema}."
            )
        is_pydantic_schema = isinstance(schema, type) and issubclass(
            schema, BaseModel
        )
        if method in ("function_calling", "tool_calling"):
            llm = self.bind_tools([schema], **kwargs)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema], first_tool_only=True
                )
            else:
                key_name = _convert_to_cohere_tool(schema)["name"]
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
        elif method == "json_mode":
            # Refers to Cohere's `json_object` mode.
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        elif method == "json_schema":
            response_format = (
                dict(
                    schema.model_json_schema().items()  # type: ignore[union-attr]
                )
                if is_pydantic_schema
                else schema
            )
            cohere_response_format: Dict[Any, Any] = {"type": "json_object"}
            cohere_response_format["schema"] = {
                k: v
                for k, v in response_format.items()  # type: ignore[union-attr]
            }
            llm = self.bind(response_format=cohere_response_format)
            if is_pydantic_schema:
                output_parser = PydanticOutputParser(pydantic_object=schema)
            else:
                output_parser = JsonOutputParser()
        else:
            raise ValueError(
                "Unrecognized method argument. Expected one of 'function_calling', "
                f"'tool_calling', 'json_schema', or 'json_mode'. Received: '{method}'"
            )
        return llm | output_parser
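
    # Illustrative sketch (hypothetical schema; requires a valid COHERE_API_KEY).
    # With the default "json_schema" method the schema is sent via Cohere's
    # Structured Output API and the raw JSON reply is parsed back into the
    # Pydantic class:
    #
    #     from pydantic import BaseModel
    #
    #     class Person(BaseModel):
    #         name: str
    #         age: int
    #
    #     llm = ChatCohere(model="command-r-plus")
    #     structured_llm = llm.with_structured_output(Person)
    #     person = structured_llm.invoke("Alice is 30 years old.")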
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "cohere-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        base_params = {
            "model": self.model,
            "temperature": self.temperature,
            "preamble": self.preamble,
        }
        return {k: v for k, v in base_params.items() if v is not None}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="cohere",
            ls_model_name=self.model_name,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens"):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _stream_v1(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = get_cohere_chat_request(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )

        if hasattr(self.client, "chat_stream"):  # detect and support sdk v5
            stream = self.client.chat_stream(**request)
        else:
            stream = self.client.chat(**request, stream=True)

        for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            if data.event_type == "tool-calls-chunk":
                if data.tool_call_delta:
                    delta = data.tool_call_delta
                    cohere_tool_call_chunk = _format_cohere_tool_calls([delta])[0]
                    message = AIMessageChunk(
                        content="",
                        tool_call_chunks=[
                            ToolCallChunk(
                                name=cohere_tool_call_chunk["function"].get("name"),
                                args=cohere_tool_call_chunk["function"].get(
                                    "arguments"
                                ),
                                id=cohere_tool_call_chunk.get("id"),
                                index=delta.index,
                            )
                        ],
                    )
                    chunk = ChatGenerationChunk(message=message)
                else:
                    delta = data.text
                    chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=delta)
                    )
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            elif data.event_type == "stream-end":
                generation_info = self._get_generation_info(data.response)
                message = AIMessageChunk(
                    content="",
                    additional_kwargs=generation_info,
                )
                yield ChatGenerationChunk(
                    message=message,
                    generation_info=generation_info,
                )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Workaround to allow create_cohere_react_agent to work with the
        # current implementation. create_cohere_react_agent relies on the
        # 'raw_prompting' parameter to be set, which is only available
        # in the v1 API.
        # TODO: Remove this workaround once create_cohere_react_agent is
        # updated to work with the v2 API.
        if kwargs.get("raw_prompting"):
            for value in self._stream_v1(
                messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                yield value
            return

        request = get_cohere_chat_request_v2(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )

        stream = self.client.v2.chat_stream(**request)
        curr_tool_call: Dict[str, Any] = copy.deepcopy(LC_TOOL_CALL_TEMPLATE)
        tool_calls = []
        for data in stream:
            if data.type == "content-delta":
                delta = data.delta.message.content.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            elif data.type in {
                "tool-call-start",
                "tool-call-delta",
                "tool-plan-delta",
                "tool-call-end",
            }:
                # tool-call-start: Contains the name of the tool function.
                #                  No arguments are included.
                # tool-call-delta: Contains the arguments of the tool function.
                #                  The function name is not included.
                # tool-plan-delta: Contains a chunk of the tool-plan message.
                # tool-call-end: End of tool call streaming.
                if data.type in {"tool-call-start", "tool-call-delta"}:
                    index = data.index
                    delta = data.delta.message

                    # To construct the current tool call you need to buffer
                    # all the deltas.
                    if data.type == "tool-call-start":
                        curr_tool_call["id"] = delta.tool_calls.id
                        curr_tool_call["function"]["name"] = (
                            delta.tool_calls.function.name
                        )
                    elif data.type == "tool-call-delta":
                        curr_tool_call["function"]["arguments"] += (
                            delta.tool_calls.function.arguments
                        )

                    # If the current stream event is a tool-call-start, then the
                    # ToolCallV2 object will only contain the function name. If
                    # the current stream event is a tool-call-delta, then the
                    # ToolCallV2 object will only contain the arguments.
                    tool_call_v2 = ToolCallV2(
                        function=ToolCallV2Function(
                            name=delta.tool_calls.function.name
                            if hasattr(delta.tool_calls.function, "name")
                            else None,
                            arguments=delta.tool_calls.function.arguments
                            if hasattr(delta.tool_calls.function, "arguments")
                            else None,
                        )
                    )

                    cohere_tool_call_chunk = _format_cohere_tool_calls_v2(
                        [tool_call_v2]
                    )[0]
                    message = AIMessageChunk(
                        content="",
                        tool_call_chunks=[
                            ToolCallChunk(
                                name=cohere_tool_call_chunk["function"].get("name"),
                                args=cohere_tool_call_chunk["function"].get(
                                    "arguments"
                                ),
                                id=cohere_tool_call_chunk.get("id"),
                                index=index,
                            )
                        ],
                    )
                    chunk = ChatGenerationChunk(message=message)
                elif data.type == "tool-plan-delta":
                    delta = data.delta.message.tool_plan
                    chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=delta)
                    )
                elif data.type == "tool-call-end":
                    # Maintain a list of all of the tool calls seen during
                    # streaming.
                    tool_calls.append(curr_tool_call)
                    curr_tool_call = copy.deepcopy(LC_TOOL_CALL_TEMPLATE)

                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            elif data.type == "message-end":
                delta = data.delta
                generation_info = self._get_stream_info_v2(
                    delta,
                    documents=request.get("documents"),
                    tool_calls=tool_calls,
                )
                message = AIMessageChunk(
                    content="",
                    additional_kwargs=generation_info,
                )
                yield ChatGenerationChunk(
                    message=message,
                    generation_info=generation_info,
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        request = get_cohere_chat_request_v2(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )

        stream = self.async_client.v2.chat_stream(**request)
        curr_tool_call: Dict[str, Any] = copy.deepcopy(LC_TOOL_CALL_TEMPLATE)
        tool_plan_deltas = []
        tool_calls = []
        async for data in stream:
            if data.type == "content-delta":
                delta = data.delta.message.content.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk
            elif data.type in {
                "tool-call-start",
                "tool-call-delta",
                "tool-plan-delta",
                "tool-call-end",
            }:
                # tool-call-start: Contains the name of the tool function.
                #                  No arguments are included.
                # tool-call-delta: Contains the arguments of the tool function.
                #                  The function name is not included.
                # tool-plan-delta: Contains a chunk of the tool-plan message.
                # tool-call-end: End of tool call streaming.
                if data.type in {"tool-call-start", "tool-call-delta"}:
                    index = data.index
                    delta = data.delta.message

                    # To construct the current tool call you need to buffer
                    # all the deltas.
                    if data.type == "tool-call-start":
                        curr_tool_call["id"] = delta.tool_calls.id
                        curr_tool_call["function"]["name"] = (
                            delta.tool_calls.function.name
                        )
                    elif data.type == "tool-call-delta":
                        curr_tool_call["function"]["arguments"] += (
                            delta.tool_calls.function.arguments
                        )

                    # If the current stream event is a tool-call-start, then the
                    # ToolCallV2 object will only contain the function name. If
                    # the current stream event is a tool-call-delta, then the
                    # ToolCallV2 object will only contain the arguments.
                    tool_call_v2 = ToolCallV2(
                        function=ToolCallV2Function(
                            name=delta.tool_calls.function.name
                            if hasattr(delta.tool_calls.function, "name")
                            else None,
                            arguments=delta.tool_calls.function.arguments
                            if hasattr(delta.tool_calls.function, "arguments")
                            else None,
                        )
                    )

                    cohere_tool_call_chunk = _format_cohere_tool_calls_v2(
                        [tool_call_v2]
                    )[0]
                    message = AIMessageChunk(
                        content="",
                        tool_call_chunks=[
                            ToolCallChunk(
                                name=cohere_tool_call_chunk["function"].get("name"),
                                args=cohere_tool_call_chunk["function"].get(
                                    "arguments"
                                ),
                                id=cohere_tool_call_chunk.get("id"),
                                index=index,
                            )
                        ],
                    )
                    chunk = ChatGenerationChunk(message=message)
                elif data.type == "tool-plan-delta":
                    delta = data.delta.message.tool_plan
                    chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=delta)
                    )
                    tool_plan_deltas.append(delta)
                elif data.type == "tool-call-end":
                    # Maintain a list of all of the tool calls seen during
                    # streaming.
                    tool_calls.append(curr_tool_call)
                    curr_tool_call = copy.deepcopy(LC_TOOL_CALL_TEMPLATE)

                if run_manager:
                    await run_manager.on_llm_new_token(delta, chunk=chunk)
            elif data.type == "message-end":
                delta = data.delta
                generation_info = self._get_stream_info_v2(
                    delta,
                    documents=request.get("documents"),
                    tool_calls=tool_calls,
                )

                tool_call_chunks = []
                if tool_calls:
                    content = "".join(tool_plan_deltas)
                    try:
                        tool_call_chunks = [
                            {
                                "name": tool_call["function"].get("name"),
                                "args": tool_call["function"].get("arguments"),
                                "id": tool_call.get("id"),
                                "index": tool_call.get("index"),
                            }
                            for tool_call in tool_calls
                        ]
                    except KeyError:
                        pass
                else:
                    content = ""

                message = AIMessageChunk(
                    content=content,
                    additional_kwargs=generation_info,
                    tool_call_chunks=tool_call_chunks,
                    usage_metadata=generation_info.get("token_count"),
                )
                yield ChatGenerationChunk(
                    message=message,
                    generation_info=generation_info,
                )

    def _get_generation_info(
        self, response: NonStreamedChatResponse
    ) -> Dict[str, Any]:
        """Get the generation info from cohere API response."""
        generation_info: Dict[str, Any] = {
            "documents": response.documents,
            "citations": response.citations,
            "search_results": response.search_results,
            "search_queries": response.search_queries,
            "is_search_required": response.is_search_required,
            "generation_id": response.generation_id,
        }
        if response.tool_calls:
            # Only populate tool_calls when 1) present on the response and
            # 2) has one or more calls.
            generation_info["tool_calls"] = _format_cohere_tool_calls(
                response.tool_calls
            )
        if hasattr(response, "token_count"):
            generation_info["token_count"] = response.token_count
        elif hasattr(response, "meta") and response.meta is not None:
            if hasattr(response.meta, "tokens") and response.meta.tokens is not None:
                generation_info["token_count"] = response.meta.tokens.dict()
        return generation_info

    def _get_generation_info_v2(
        self,
        response: ChatResponse,
        documents: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Get the generation info from cohere API response (V2)."""
        generation_info: Dict[str, Any] = {
            "id": response.id,
            "finish_reason": response.finish_reason,
        }
        if documents:
            generation_info["documents"] = documents

        if response.message:
            if response.message.tool_plan:
                generation_info["tool_plan"] = response.message.tool_plan
            if response.message.tool_calls:
                generation_info["tool_calls"] = _format_cohere_tool_calls_v2(
                    response.message.tool_calls
                )
            if response.message.content:
                generation_info["content"] = response.message.content[0].text
            if response.message.citations:
                generation_info["citations"] = response.message.citations

        if response.usage:
            if response.usage.tokens:
                generation_info["token_count"] = response.usage.tokens.dict()
        return generation_info

    def _get_stream_info_v2(
        self,
        final_delta: Any,
        documents: Optional[List[Dict[str, Any]]] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Get the stream info from cohere API response (V2)."""
        input_tokens = final_delta.usage.billed_units.input_tokens
        output_tokens = final_delta.usage.billed_units.output_tokens
        total_tokens = input_tokens + output_tokens
        stream_info = {
            "finish_reason": final_delta.finish_reason,
            "token_count": {
                "total_tokens": total_tokens,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            },
        }
        if documents:
            stream_info["documents"] = documents
        if tool_calls:
            stream_info["tool_calls"] = tool_calls
        return stream_info

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        request = get_cohere_chat_request_v2(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )
        response = self.client.v2.chat(**request)

        generation_info = self._get_generation_info_v2(
            response, request.get("documents")
        )
        if "tool_calls" in generation_info:
            content = response.message.tool_plan if response.message.tool_plan else ""
            tool_calls = [
                lc_tool_call
                for tool_call in response.message.tool_calls
                if (
                    lc_tool_call := _convert_cohere_v2_tool_call_to_langchain(
                        tool_call
                    )
                )
            ]
        else:
            content = (
                response.message.content[0].text if response.message.content else ""
            )
            tool_calls = []
        usage_metadata = _get_usage_metadata_v2(response)
        message = AIMessage(
            content=content,
            additional_kwargs=generation_info,
            tool_calls=tool_calls,
            usage_metadata=usage_metadata,
        )
        return ChatResult(
            generations=[
                ChatGeneration(message=message, generation_info=generation_info)
            ]
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        request = get_cohere_chat_request_v2(
            messages, stop_sequences=stop, **self._default_params, **kwargs
        )
        response = await self.async_client.v2.chat(**request)

        generation_info = self._get_generation_info_v2(
            response, request.get("documents")
        )
        if "tool_calls" in generation_info:
            content = response.message.tool_plan if response.message.tool_plan else ""
            tool_calls = [
                lc_tool_call
                for tool_call in response.message.tool_calls
                if (
                    lc_tool_call := _convert_cohere_v2_tool_call_to_langchain(
                        tool_call
                    )
                )
            ]
        else:
            content = (
                response.message.content[0].text if response.message.content else ""
            )
            tool_calls = []
        usage_metadata = _get_usage_metadata_v2(response)
        message = AIMessage(
            content=content,
            additional_kwargs=generation_info,
            tool_calls=tool_calls,
            usage_metadata=usage_metadata,
        )
        return ChatResult(
            generations=[
                ChatGeneration(message=message, generation_info=generation_info)
            ]
        )

    @property
    def model_name(self) -> str:
        if self.model is not None:
            return self.model
        if self._default_model_name is None:
            self._default_model_name = self._get_default_model()
        return self._default_model_name
    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        model = self.model_name
        return len(self.client.tokenize(text=text, model=model).tokens)
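
# Illustrative sketch (requires a valid COHERE_API_KEY): token counting calls the
# hosted tokenizer endpoint with the resolved model name, e.g.:
#
#     llm = ChatCohere(model="command-r-plus")
#     n_tokens = llm.get_num_tokens("Hello, world!")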
def _format_cohere_tool_calls(
    tool_calls: Optional[List[ToolCall]] = None,
) -> List[Dict]:
    """
    Formats a Cohere API response into the tool call format used elsewhere in
    Langchain.
    """
    if not tool_calls:
        return []

    formatted_tool_calls = []
    for tool_call in tool_calls:
        formatted_tool_calls.append(
            {
                "id": uuid.uuid4().hex,
                "type": "function",
                "function": {
                    "name": tool_call.name,
                    "arguments": json.dumps(tool_call.parameters),
                },
            }
        )
    return formatted_tool_calls


def _format_cohere_tool_calls_v2(
    tool_calls: Optional[List[ToolCallV2]] = None,
) -> List[Dict[str, Any]]:
    """
    Formats a V2 Cohere API response into the tool call format used elsewhere in
    Langchain.
    """
    if not tool_calls:
        return []

    formatted_tool_calls = []
    for tool_call in tool_calls:
        if not tool_call.function:
            continue
        formatted_tool_calls.append(
            {
                "id": tool_call.id or uuid.uuid4().hex,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
        )
    return formatted_tool_calls


def _convert_cohere_tool_call_to_langchain(tool_call: ToolCall) -> LC_ToolCall:
    """Convert a Cohere tool call into langchain_core.messages.ToolCall"""
    _id = uuid.uuid4().hex
    return LC_ToolCall(name=tool_call.name, args=tool_call.parameters, id=_id)


def _convert_cohere_v2_tool_call_to_langchain(
    tool_call: ToolCallV2,
) -> Optional[LC_ToolCall]:
    """Convert a Cohere V2 tool call into langchain_core.messages.ToolCall"""
    _id = tool_call.id or uuid.uuid4().hex
    if not tool_call.function or not tool_call.function.name:
        return None
    return LC_ToolCall(
        name=str(tool_call.function.name),
        args=json.loads(tool_call.function.arguments)
        if tool_call.function.arguments
        else {},
        id=_id,
    )


def _get_usage_metadata(response: NonStreamedChatResponse) -> Optional[UsageMetadata]:
    """Get standard usage metadata from chat response."""
    metadata = response.meta
    if metadata:
        if tokens := metadata.tokens:
            input_tokens = int(tokens.input_tokens or 0)
            output_tokens = int(tokens.output_tokens or 0)
            total_tokens = input_tokens + output_tokens
            return UsageMetadata(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
            )
    return None


def _get_usage_metadata_v2(response: ChatResponse) -> Optional[UsageMetadata]:
    """Get standard usage metadata from chat response (V2)."""
    metadata = response.usage
    if metadata:
        if tokens := metadata.tokens:
            input_tokens = int(tokens.input_tokens or 0)
            output_tokens = int(tokens.output_tokens or 0)
            total_tokens = input_tokens + output_tokens
            return UsageMetadata(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
            )
    return None
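
# Illustrative sketch (hypothetical values): a Cohere V2 tool call converts into a
# LangChain ToolCall, with the JSON-encoded arguments parsed back into a dict.
def _example_v2_tool_call_conversion() -> None:
    v2_call = ToolCallV2(
        id="call-1",
        type="function",
        function=ToolCallV2Function(
            name="get_weather", arguments='{"city": "Toronto"}'
        ),
    )
    lc_call = _convert_cohere_v2_tool_call_to_langchain(v2_call)
    assert lc_call is not None
    assert lc_call["name"] == "get_weather"
    assert lc_call["args"] == {"city": "Toronto"}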