Source code for langchain_community.llms.ai21

from typing import Any, Dict, List, Optional, cast

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import BaseModel, ConfigDict, SecretStr


class AI21PenaltyData(BaseModel):
    """Settings for one AI21 penalty block.

    The ``AI21`` model in this module holds three of these (presence, count
    and frequency penalties) and dumps each one to a plain dict that is sent
    in the request body, which is why the field names are camelCase.
    """

    # Penalty scale; the default of 0 is what gets sent when nothing is set.
    scale: int = 0

    # Boolean switches, all on by default. They are named after token
    # categories; their exact effect is defined by the AI21 API, not here.
    applyToWhitespaces: bool = True
    applyToPunctuations: bool = True
    applyToNumbers: bool = True
    applyToStopwords: bool = True
    applyToEmojis: bool = True
class AI21(LLM):
    """AI21 large language models.

    To use, you should have the environment variable ``AI21_API_KEY``
    set with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import AI21
            ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct")
    """

    model: str = "j2-jumbo-instruct"
    """Model name to use."""

    temperature: float = 0.7
    """What sampling temperature to use."""

    maxTokens: int = 256
    """The maximum number of tokens to generate in the completion."""

    minTokens: int = 0
    """The minimum number of tokens to generate in the completion."""

    topP: float = 1.0
    """Total probability mass of tokens to consider at each step."""

    presencePenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens."""

    countPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to count."""

    frequencyPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to frequency."""

    numResults: int = 1
    """How many completions to generate for each prompt."""

    logitBias: Optional[Dict[str, float]] = None
    """Adjust the probability of specific tokens being generated."""

    ai21_api_key: Optional[SecretStr] = None
    """API key; filled in by ``validate_environment`` from the constructor
    argument or the ``AI21_API_KEY`` environment variable."""

    stop: Optional[List[str]] = None
    """Default stop sequences. Mutually exclusive with the ``stop`` argument
    of a call: supplying both raises ``ValueError``."""

    base_url: Optional[str] = None
    """Base url to use, if None decides based on model name."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment.

        Looks the key up under ``ai21_api_key`` in ``values`` or under the
        ``AI21_API_KEY`` environment variable and stores it back into
        ``values`` wrapped as a secret string.
        """
        ai21_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
        )
        values["ai21_api_key"] = ai21_api_key
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling AI21 API."""
        return {
            "temperature": self.temperature,
            "maxTokens": self.maxTokens,
            "minTokens": self.minTokens,
            "topP": self.topP,
            # model_dump() is the pydantic v2 spelling of the deprecated
            # .dict(); the nested penalty models become plain dicts.
            "presencePenalty": self.presencePenalty.model_dump(),
            "countPenalty": self.countPenalty.model_dump(),
            "frequencyPenalty": self.frequencyPenalty.model_dump(),
            "numResults": self.numResults,
            "logitBias": self.logitBias,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ai21"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to AI21's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra request parameters; they override the defaults
                from ``_default_params`` key by key.

        Returns:
            The string generated by the model (text of the first completion).

        Raises:
            ValueError: If ``stop`` is given both here and on the instance,
                or if the API answers with a non-200 status code.

        Example:
            .. code-block:: python

                response = ai21("Tell me a joke.")
        """
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            stop = self.stop
        elif stop is None:
            stop = []

        if self.base_url is not None:
            base_url = self.base_url
        elif self.model in ("j1-grande-instruct",):
            base_url = "https://api.ai21.com/studio/v1/experimental"
        else:
            base_url = "https://api.ai21.com/studio/v1"

        params = {**self._default_params, **kwargs}

        # Narrow the Optional type locally instead of writing back to self:
        # a completion call should not mutate the model's fields.
        api_key = cast(SecretStr, self.ai21_api_key)

        # NOTE: no timeout is passed, so requests will wait indefinitely on
        # an unresponsive server.
        response = requests.post(
            url=f"{base_url}/{self.model}/complete",
            headers={"Authorization": f"Bearer {api_key.get_secret_value()}"},
            json={"prompt": prompt, "stopSequences": stop, **params},
        )

        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON objects (a proxy may
            # answer with HTML, for instance). Never let a decode failure
            # hide the status code; fall back to the raw text instead.
            try:
                error_body = response.json()
            except ValueError:
                optional_detail = response.text
            else:
                if isinstance(error_body, dict):
                    optional_detail = error_body.get("error")
                else:
                    optional_detail = error_body
            raise ValueError(
                f"AI21 /complete call failed with status code {response.status_code}."
                f" Details: {optional_detail}"
            )

        response_json = response.json()
        return response_json["completions"][0]["data"]["text"]