def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur.

    Args:
        text: The generated text to truncate.
        stop: Stop sequences; ``text`` is cut at the first occurrence of any.

    Returns:
        The prefix of ``text`` before the earliest stop sequence, or the
        whole text if no stop sequence occurs (or ``stop`` is empty).
    """
    # Guard: an empty stop list would otherwise build an empty regex
    # pattern, which splits at position 0 and returns "".
    if not stop:
        return text
    # Escape each stop sequence so regex metacharacters (".", "|", "(",
    # etc.) are matched literally rather than interpreted as patterns.
    pattern = "|".join(re.escape(s) for s in stop)
    return re.split(pattern, text, maxsplit=1)[0]
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    wrap_with_retries = _create_retry_decorator(llm.max_retries)

    @wrap_with_retries
    def _generate_once(**call_kwargs: Any) -> Any:
        # A single attempt; tenacity re-invokes this on retryable errors.
        return llm.client.generate(**call_kwargs)

    return _generate_once(**kwargs)
def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    wrap_with_retries = _create_retry_decorator(llm.max_retries)

    @wrap_with_retries
    async def _agenerate_once(**call_kwargs: Any) -> Any:
        # A single async attempt; tenacity re-invokes this on retryable errors.
        return await llm.async_client.generate(**call_kwargs)

    # Note: this is a sync function returning a coroutine — the caller awaits it.
    return _agenerate_once(**kwargs)
class BaseCohere(Serializable):
    """Base class for Cohere models."""

    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    model: Optional[str] = Field(default=None)
    """Model name to use."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    cohere_api_key: Optional[SecretStr] = None
    """Cohere API key. If not provided, will be read from the environment variable."""

    stop: Optional[List[str]] = None

    streaming: bool = Field(default=False)
    """Whether to stream the results."""

    user_agent: str = "langchain:partner"
    """Identifier for the application making the request."""

    timeout_seconds: Optional[float] = 300
    """Timeout in seconds for the Cohere API request."""

    base_url: Optional[str] = None
    """Override the default Cohere API URL."""

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
        )
        values["cohere_api_key"] = api_key

        # The sync and async clients share identical identity/transport
        # configuration, so build it once.
        shared_kwargs = {
            "api_key": api_key.get_secret_value(),
            "client_name": values["user_agent"],
            "timeout": values.get("timeout_seconds"),
            "base_url": values["base_url"],
        }
        values["client"] = cohere.Client(**shared_kwargs)
        values["async_client"] = cohere.AsyncClient(**shared_kwargs)
        return values
class Cohere(LLM, BaseCohere):
    """Cohere large language models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_cohere import Cohere

            cohere = Cohere(cohere_api_key="my-api-key")
    """

    max_tokens: Optional[int] = None
    """Denotes the number of tokens to predict per generation."""

    k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    # BUGFIX: was ``Optional[int]``. ``p`` is a probability mass in [0, 1];
    # an int annotation lets pydantic coerce e.g. 0.75 to 0, silently
    # disabling nucleus sampling. Widening to float is backward compatible.
    p: Optional[float] = None
    """Total probability mass of tokens to consider at each step."""

    frequency_penalty: Optional[float] = None
    """Penalizes repeated tokens according to frequency. Between 0 and 1."""

    presence_penalty: Optional[float] = None
    """Penalizes repeated tokens. Between 0 and 1."""

    truncate: Optional[str] = None
    """Specify how the client handles inputs longer than the maximum token
    length: Truncate from START, END or NONE"""

    max_retries: int = 10
    """Maximum number of retries to make when generating."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        extra = Extra.forbid

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Configurable parameters for calling Cohere's generate API."""
        base_params = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "k": self.k,
            "p": self.p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "truncate": self.truncate,
        }
        # Drop unset values so the API applies its own defaults.
        return {k: v for k, v in base_params.items() if v is not None}

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"cohere_api_key": "COHERE_API_KEY"}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "cohere"

    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
        """Merge default params, stop sequences, and call-time overrides.

        Raises:
            ValueError: If stop sequences are supplied both on the instance
                and at call time.
        """
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            params["stop_sequences"] = self.stop
        else:
            params["stop_sequences"] = stop
        return {**params, **kwargs}

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        """Extract the generated text from a Cohere generate response."""
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Cohere's generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = cohere("Tell me a joke.")
        """
        params = self._invocation_params(stop, **kwargs)
        # BUGFIX: ``params`` already carries "model" (via _default_params)
        # whenever self.model is set, so passing model=self.model as an
        # explicit keyword alongside **params raised
        # ``TypeError: got multiple values for keyword argument 'model'``.
        # Merge instead; both sources hold the same value, and model=None
        # is still forwarded when unset (preserving the original behavior).
        response = completion_with_retry(
            self, prompt=prompt, **{"model": self.model, **params}
        )
        _stop = params.get("stop_sequences")
        return self._process_response(response, _stop)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async call out to Cohere's generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = await cohere("Tell me a joke.")
        """
        params = self._invocation_params(stop, **kwargs)
        # Same duplicate-"model"-keyword fix as in _call above.
        response = await acompletion_with_retry(
            self, prompt=prompt, **{"model": self.model, **params}
        )
        _stop = params.get("stop_sequences")
        return self._process_response(response, _stop)