# Imports needed by this module (assumed: langchain-core with pydantic v2).
from typing import Any, Dict, List, Optional, cast

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import BaseModel, ConfigDict, SecretStr


class AI21PenaltyData(BaseModel):
    """Parameters for AI21 penalty data."""

    scale: int = 0
    applyToWhitespaces: bool = True
    applyToPunctuations: bool = True
    applyToNumbers: bool = True
    applyToStopwords: bool = True
    applyToEmojis: bool = True
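
# Illustrative sketch (not part of the library): AI21PenaltyData mirrors the
# camelCase penalty schema of the AI21 API. A presence penalty that skips
# whitespace and stop words (scale value chosen arbitrarily) could be built as:
#
#     penalty = AI21PenaltyData(scale=2, applyToWhitespaces=False,
#                               applyToStopwords=False)
#
# and passed to the AI21 class below, e.g. ``AI21(presencePenalty=penalty)``.
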
class AI21(LLM):
    """AI21 large language models.

    To use, you should have the environment variable ``AI21_API_KEY``
    set with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import AI21

            ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct")
    """

    model: str = "j2-jumbo-instruct"
    """Model name to use."""

    temperature: float = 0.7
    """What sampling temperature to use."""

    maxTokens: int = 256
    """The maximum number of tokens to generate in the completion."""

    minTokens: int = 0
    """The minimum number of tokens to generate in the completion."""

    topP: float = 1.0
    """Total probability mass of tokens to consider at each step."""

    presencePenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens."""

    countPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to count."""

    frequencyPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to frequency."""

    numResults: int = 1
    """How many completions to generate for each prompt."""

    logitBias: Optional[Dict[str, float]] = None
    """Adjust the probability of specific tokens being generated."""

    ai21_api_key: Optional[SecretStr] = None

    stop: Optional[List[str]] = None

    base_url: Optional[str] = None
    """Base url to use, if None decides based on model name."""

    model_config = ConfigDict(
        extra="forbid",
    )
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        ai21_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
        )
        values["ai21_api_key"] = ai21_api_key
        return values
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling AI21 API."""
        return {
            "temperature": self.temperature,
            "maxTokens": self.maxTokens,
            "minTokens": self.minTokens,
            "topP": self.topP,
            "presencePenalty": self.presencePenalty.dict(),
            "countPenalty": self.countPenalty.dict(),
            "frequencyPenalty": self.frequencyPenalty.dict(),
            "numResults": self.numResults,
            "logitBias": self.logitBias,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ai21"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to AI21's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = ai21("Tell me a joke.")
        """
        # Stop sequences may come from the instance default or the call,
        # but not both.
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            stop = self.stop
        elif stop is None:
            stop = []

        # Pick the endpoint: experimental models live under a separate base url.
        if self.base_url is not None:
            base_url = self.base_url
        else:
            if self.model in ("j1-grande-instruct",):
                base_url = "https://api.ai21.com/studio/v1/experimental"
            else:
                base_url = "https://api.ai21.com/studio/v1"

        params = {**self._default_params, **kwargs}
        self.ai21_api_key = cast(SecretStr, self.ai21_api_key)
        response = requests.post(
            url=f"{base_url}/{self.model}/complete",
            headers={
                "Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"
            },
            json={"prompt": prompt, "stopSequences": stop, **params},
        )
        if response.status_code != 200:
            optional_detail = response.json().get("error")
            raise ValueError(
                f"AI21 /complete call failed with status code {response.status_code}."
                f" Details: {optional_detail}"
            )
        response_json = response.json()
        return response_json["completions"][0]["data"]["text"]
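

# Illustrative usage sketch (not part of the library): assumes the
# AI21_API_KEY environment variable is set and that "j2-jumbo-instruct"
# is available for the account. Run this module directly to try it.
if __name__ == "__main__":
    llm = AI21(
        model="j2-jumbo-instruct",
        temperature=0.3,
        maxTokens=100,
        presencePenalty=AI21PenaltyData(scale=3),
    )
    # `invoke` is the standard LangChain entry point; the stop sequences are
    # forwarded to the API as `stopSequences` by `_call` above.
    print(llm.invoke("Tell me a joke.", stop=["\n\n"]))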