from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import ConfigDict

from langchain_community.llms.cohere import BaseCohere


def get_role(message: BaseMessage) -> str:
    """Get the role of the message.

    Args:
        message: The message.

    Returns:
        The role of the message.

    Raises:
        ValueError: If the message is of an unknown type.
    """
    if isinstance(message, (ChatMessage, HumanMessage)):
        return "User"
    elif isinstance(message, AIMessage):
        return "Chatbot"
    elif isinstance(message, SystemMessage):
        return "System"
    else:
        raise ValueError(f"Got unknown type {message}")
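
# Illustrative usage sketch for ``get_role`` (not part of the original module):
# the returned strings are the role names the Cohere chat API expects.
#
#   >>> get_role(HumanMessage(content="knock knock"))
#   'User'
#   >>> get_role(AIMessage(content="who's there?"))
#   'Chatbot'
#   >>> get_role(SystemMessage(content="Answer as a comedian."))
#   'System'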
def get_cohere_chat_request(
    messages: List[BaseMessage],
    *,
    connectors: Optional[List[Dict[str, str]]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Get the request for the Cohere chat API.

    Args:
        messages: The messages.
        connectors: The connectors.
        **kwargs: The keyword arguments.

    Returns:
        The request for the Cohere chat API.
    """
    documents = (
        None
        if "source_documents" not in kwargs
        else [
            {
                "snippet": doc.page_content,
                "id": doc.metadata.get("id") or f"doc-{str(i)}",
            }
            for i, doc in enumerate(kwargs["source_documents"])
        ]
    )
    kwargs.pop("source_documents", None)
    maybe_connectors = connectors if documents is None else None

    # By enabling automatic prompt truncation, the probability of request failure
    # is reduced with minimal impact on response quality.
    prompt_truncation = (
        "AUTO" if documents is not None or connectors is not None else None
    )

    req = {
        "message": messages[-1].content,
        "chat_history": [
            {"role": get_role(x), "message": x.content} for x in messages[:-1]
        ],
        "documents": documents,
        "connectors": maybe_connectors,
        "prompt_truncation": prompt_truncation,
        **kwargs,
    }

    return {k: v for k, v in req.items() if v is not None}
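
# Illustrative sketch of ``get_cohere_chat_request`` (not part of the original
# module; ``Document`` is langchain_core's document class). With
# ``source_documents`` supplied, the documents are inlined, ``connectors`` is
# dropped, ``prompt_truncation`` becomes "AUTO", and None-valued keys are removed:
#
#   >>> from langchain_core.documents import Document
#   >>> request = get_cohere_chat_request(
#   ...     [
#   ...         HumanMessage(content="hi"),
#   ...         AIMessage(content="hello"),
#   ...         HumanMessage(content="what's new?"),
#   ...     ],
#   ...     source_documents=[Document(page_content="snippet text")],
#   ... )
#   >>> request["message"]
#   "what's new?"
#   >>> request["chat_history"]
#   [{'role': 'User', 'message': 'hi'}, {'role': 'Chatbot', 'message': 'hello'}]
#   >>> request["documents"]
#   [{'snippet': 'snippet text', 'id': 'doc-0'}]
#   >>> "connectors" in request
#   False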
@deprecated(
    since="0.0.30", removal="1.0", alternative_import="langchain_cohere.ChatCohere"
)
class ChatCohere(BaseChatModel, BaseCohere):
    """`Cohere` chat large language models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatCohere
            from langchain_core.messages import HumanMessage

            chat = ChatCohere(max_tokens=256, temperature=0.75)

            messages = [HumanMessage(content="knock knock")]
            chat.invoke(messages)
    """

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "cohere-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        return {
            "temperature": self.temperature,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)

        if hasattr(self.client, "chat_stream"):  # detect and support sdk v5
            stream = self.client.chat_stream(**request)
        else:
            stream = self.client.chat(**request, stream=True)

        for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)

        if hasattr(self.async_client, "chat_stream"):  # detect and support sdk v5
            # chat_stream returns an async iterator directly; it must not be awaited
            stream = self.async_client.chat_stream(**request)
        else:
            stream = await self.async_client.chat(**request, stream=True)

        async for data in stream:
            if data.event_type == "text-generation":
                delta = data.text
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta, chunk=chunk)
                yield chunk

    def _get_generation_info(self, response: Any) -> Dict[str, Any]:
        """Get the generation info from cohere API response."""
        return {
            "documents": response.documents,
            "citations": response.citations,
            "search_results": response.search_results,
            "search_queries": response.search_queries,
            "token_count": response.token_count,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
        response = self.client.chat(**request)

        message = AIMessage(content=response.text)
        generation_info = None
        if hasattr(response, "documents"):
            generation_info = self._get_generation_info(response)
        return ChatResult(
            generations=[
                ChatGeneration(message=message, generation_info=generation_info)
            ]
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
        # Use the async client here so the event loop is not blocked by a
        # synchronous HTTP call.
        response = await self.async_client.chat(**request)

        message = AIMessage(content=response.text)
        generation_info = None
        if hasattr(response, "documents"):
            generation_info = self._get_generation_info(response)
        return ChatResult(
            generations=[
                ChatGeneration(message=message, generation_info=generation_info)
            ]
        )
    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        return len(self.client.tokenize(text=text).tokens)
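
# End-to-end usage sketch (illustrative, not part of the original module;
# assumes the ``cohere`` package is installed and ``COHERE_API_KEY`` is set,
# as described in the class docstring):
#
#   >>> chat = ChatCohere(temperature=0.75)
#   >>> for chunk in chat.stream([HumanMessage(content="knock knock")]):
#   ...     print(chunk.content, end="", flush=True)
#   >>> chat.get_num_tokens("knock knock")  # counted via client.tokenize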