from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field, SecretStr
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)


def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    import cohere

    # support v4 and v5
    retry_conditions = (
        retry_if_exception_type(cohere.error.CohereError)
        if hasattr(cohere, "error")
        else retry_if_exception_type(Exception)
    )

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=retry_conditions,
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
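
# Sketch (hypothetical, not part of the module): the decorator built above can
# wrap any callable; on a Cohere error it retries up to ``max_retries`` times
# with exponential backoff capped between 4 and 10 seconds, e.g.:
#
#     decorated = _create_retry_decorator(max_retries=3)(flaky_generate_call)
#     result = decorated()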

def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm.max_retries)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return llm.client.generate(**kwargs)

    return _completion_with_retry(**kwargs)

def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm.max_retries)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return await llm.async_client.generate(**kwargs)

    # Returns a coroutine; callers are expected to await the result.
    return _completion_with_retry(**kwargs)
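
# Usage sketch (hypothetical ``llm`` instance and prompt, not part of the
# module): both wrappers forward keyword arguments to the underlying
# ``generate`` call; the async variant hands back a coroutine to await:
#
#     response = completion_with_retry(llm, model=llm.model, prompt="Hello")
#     response = await acompletion_with_retry(llm, model=llm.model, prompt="Hello")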
[docs]@deprecated(since="0.0.30",removal="1.0",alternative_import="langchain_cohere.BaseCohere")classBaseCohere(Serializable):"""Base class for Cohere models."""client:Any#: :meta private:async_client:Any#: :meta private:model:Optional[str]=Field(default=None)"""Model name to use."""temperature:float=0.75"""A non-negative float that tunes the degree of randomness in generation."""cohere_api_key:Optional[SecretStr]=None"""Cohere API key. If not provided, will be read from the environment variable."""stop:Optional[List[str]]=Nonestreaming:bool=Field(default=False)"""Whether to stream the results."""user_agent:str="langchain""""Identifier for the application making the request."""@pre_initdefvalidate_environment(cls,values:Dict)->Dict:"""Validate that api key and python package exists in environment."""try:importcohereexceptImportError:raiseImportError("Could not import cohere python package. ""Please install it with `pip install cohere`.")else:values["cohere_api_key"]=convert_to_secret_str(get_from_dict_or_env(values,"cohere_api_key","COHERE_API_KEY"))client_name=values["user_agent"]values["client"]=cohere.Client(api_key=values["cohere_api_key"].get_secret_value(),client_name=client_name,)values["async_client"]=cohere.AsyncClient(api_key=values["cohere_api_key"].get_secret_value(),client_name=client_name,)returnvalues
[docs]@deprecated(since="0.1.14",removal="1.0",alternative_import="langchain_cohere.Cohere")classCohere(LLM,BaseCohere):"""Cohere large language models. To use, you should have the ``cohere`` python package installed, and the environment variable ``COHERE_API_KEY`` set with your API key, or pass it as a named parameter to the constructor. Example: .. code-block:: python from langchain_community.llms import Cohere cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key") """max_tokens:int=256"""Denotes the number of tokens to predict per generation."""k:int=0"""Number of most likely tokens to consider at each step."""p:int=1"""Total probability mass of tokens to consider at each step."""frequency_penalty:float=0.0"""Penalizes repeated tokens according to frequency. Between 0 and 1."""presence_penalty:float=0.0"""Penalizes repeated tokens. Between 0 and 1."""truncate:Optional[str]=None"""Specify how the client handles inputs longer than the maximum token length: Truncate from START, END or NONE"""max_retries:int=10"""Maximum number of retries to make when generating."""classConfig:extra="forbid"@propertydef_default_params(self)->Dict[str,Any]:"""Get the default parameters for calling Cohere API."""return{"max_tokens":self.max_tokens,"temperature":self.temperature,"k":self.k,"p":self.p,"frequency_penalty":self.frequency_penalty,"presence_penalty":self.presence_penalty,"truncate":self.truncate,}@propertydeflc_secrets(self)->Dict[str,str]:return{"cohere_api_key":"COHERE_API_KEY"}@propertydef_identifying_params(self)->Dict[str,Any]:"""Get the identifying parameters."""return{**{"model":self.model},**self._default_params}@propertydef_llm_type(self)->str:"""Return type of llm."""return"cohere"def_invocation_params(self,stop:Optional[List[str]],**kwargs:Any)->dict:params=self._default_paramsifself.stopisnotNoneandstopisnotNone:raiseValueError("`stop` found in both the input and default params.")elifself.stopisnotNone:params["stop_sequences"]=self.stopelse:params["stop_sequences"]=stopreturn{**params,**kwargs}def_process_response(self,response:Any,stop:Optional[List[str]])->str:text=response.generations[0].text# If stop tokens are provided, Cohere's endpoint returns them.# In order to make this consistent with other endpoints, we strip them.ifstop:text=enforce_stop_tokens(text,stop)returntextdef_call(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->str:"""Call out to Cohere's generate endpoint. Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. Returns: The string generated by the model. Example: .. code-block:: python response = cohere("Tell me a joke.") """params=self._invocation_params(stop,**kwargs)response=completion_with_retry(self,model=self.model,prompt=prompt,**params)_stop=params.get("stop_sequences")returnself._process_response(response,_stop)asyncdef_acall(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->str:"""Async call out to Cohere's generate endpoint. Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. Returns: The string generated by the model. Example: .. code-block:: python response = await cohere("Tell me a joke.") """params=self._invocation_params(stop,**kwargs)response=awaitacompletion_with_retry(self,model=self.model,prompt=prompt,**params)_stop=params.get("stop_sequences")returnself._process_response(response,_stop)