# Source code for langchain_ai21.ai21_base
from typing import Any, Optional
from ai21 import AI21Client
from langchain_core.utils import from_env, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
_DEFAULT_TIMEOUT_SEC = 300
[docs]
class AI21Base(BaseModel):
    """Base class for AI21 models.

    Collects the connection settings shared by AI21 integrations (API key,
    host, timeout, retry budget) and, once validation has finished, builds an
    ``AI21Client`` from them unless a client was supplied explicitly.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    client: Any = Field(default=None, exclude=True)  #: :meta private:

    api_key: SecretStr = Field(
        default_factory=secret_from_env("AI21_API_KEY", default="")
    )
    """API key for AI21 API.

    If not set, it will default to the value of the environment variable
    `AI21_API_KEY`, or an empty string if that is unset too.
    """

    api_host: str = Field(
        default_factory=from_env("AI21_API_URL", default="https://api.ai21.com")
    )
    """Host URL of the AI21 API.

    If not set, it will default to the value of the environment variable
    `AI21_API_URL`, or `https://api.ai21.com`.
    """

    timeout_sec: float = Field(
        default_factory=lambda: float(
            from_env("AI21_TIMEOUT_SEC", default=str(_DEFAULT_TIMEOUT_SEC))()
        )
    )
    """Timeout in seconds.

    If not set, it will default to the value of the environment
    variable `AI21_TIMEOUT_SEC` or 300 seconds.
    """

    num_retries: Optional[int] = None
    """Maximum number of retries for API requests before giving up.

    Forwarded to ``AI21Client`` as-is (including `None`).
    """

    @model_validator(mode="after")
    def post_init(self) -> Self:
        """Build the AI21 client from the validated settings.

        An explicitly supplied ``client`` is kept as-is. A new ``AI21Client``
        is created only when ``client`` is unset (``None`` or otherwise falsy).

        Returns:
            The model instance, as required for ``mode="after"`` validators.
        """
        if not self.client:
            timeout_sec = self.timeout_sec
            self.client = AI21Client(
                api_key=self.api_key.get_secret_value(),
                api_host=self.api_host,
                # Guard kept so a subclass widening the field to Optional works.
                timeout_sec=None if timeout_sec is None else float(timeout_sec),
                # Forward the documented retry budget; without this the
                # `num_retries` field had no effect on the client.
                num_retries=self.num_retries,
                via="langchain",
            )
        return self