from __future__ import annotations
from enum import Enum, auto
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import google.api_core
import google.generativeai as genai # type: ignore[import]
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LangSmithParams, LanguageModelInput
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_google_genai._enums import (
HarmBlockThreshold,
HarmCategory,
)
[docs]
class GoogleModelFamily(str, Enum):
    """Model families recognised by this module.

    Members are normally obtained by calling the enum with a model name,
    e.g. ``GoogleModelFamily("gemini-pro")``; the name is matched by
    case-insensitive substring in :meth:`_missing_`.
    """

    GEMINI = auto()
    PALM = auto()

    @classmethod
    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
        """Map a model name onto a family when it is not an exact member value.

        Args:
            value: The object the enum was called with.

        Returns:
            ``GEMINI`` if the lower-cased name contains ``"gemini"``, ``PALM``
            if it contains ``"text-bison"``, otherwise ``None`` (which makes
            ``Enum`` raise its usual ``ValueError``).
        """
        # Non-string lookups cannot be model names; returning None lets Enum
        # raise ValueError instead of leaking an AttributeError from .lower().
        if not isinstance(value, str):
            return None
        name = value.lower()
        if "gemini" in name:
            return cls.GEMINI
        if "text-bison" in name:
            return cls.PALM
        return None
def _create_retry_decorator(
    llm: BaseLLM,
    *,
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used around Google API calls.

    Args:
        llm: The calling LLM. Accepted for signature compatibility; it is not
            read by this function.
        max_retries: Forwarded to ``create_base_retry_decorator``.
        run_manager: Optional callback manager, forwarded unchanged.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for the
        exception types listed below.
    """
    api_errors = google.api_core.exceptions
    retryable = [
        api_errors.ResourceExhausted,
        api_errors.ServiceUnavailable,
        api_errors.Aborted,
        api_errors.DeadlineExceeded,
        api_errors.GoogleAPIError,
    ]
    return create_base_retry_decorator(
        error_types=retryable,
        max_retries=max_retries,
        run_manager=run_manager,
    )
def _completion_with_retry(
    llm: GoogleGenerativeAI,
    prompt: LanguageModelInput,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call the model through ``llm.client``, retrying on Google API errors.

    Args:
        llm: The LLM wrapper; its ``client``, ``timeout`` and ``max_retries``
            attributes are read.
        prompt: The prompt passed to the client.
        is_gemini: When true, ``llm.client.generate_content`` is called;
            otherwise ``llm.client.generate_text``.
        stream: Forwarded to ``generate_content`` (Gemini path only).
        run_manager: Optional callback manager handed to the retry decorator.
        **kwargs: On the Gemini path only ``generation_config`` and
            ``safety_settings`` are read. On the other path every keyword is
            forwarded to ``generate_text``.

    Returns:
        Whatever the underlying client call returns.

    Raises:
        ValueError: If the API reports that the caller's location is not
            supported.
        google.api_core.exceptions.FailedPrecondition: For any other
            failed-precondition error (previously these were swallowed and
            ``None`` was returned).
    """
    retry_decorator = _create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )
    error_msg = (
        "Your location is not supported by google-generativeai at the moment. "
        "Try to use VertexAI LLM from langchain_google_vertexai"
    )

    @retry_decorator
    def _call_with_retry(
        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
    ) -> Any:
        try:
            if is_gemini:
                return llm.client.generate_content(
                    contents=prompt,
                    stream=stream,
                    generation_config=kwargs.get("generation_config", {}),
                    safety_settings=kwargs.pop("safety_settings", None),
                    request_options={"timeout": llm.timeout} if llm.timeout else None,
                )
            return llm.client.generate_text(prompt=prompt, **kwargs)
        except google.api_core.exceptions.FailedPrecondition as exc:
            if "location is not supported" in (exc.message or ""):
                raise ValueError(error_msg) from exc
            # Any other failed precondition is a real error: propagate it
            # rather than falling through and returning None to the caller.
            raise

    return _call_with_retry(
        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
    )
def _strip_erroneous_leading_spaces(text: str) -> str:
"""Strip erroneous leading spaces from text.
The PaLM API will sometimes erroneously return a single leading space in all
lines > 1. This function strips that space.
"""
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
if has_leading_space:
return text.replace("\n ", "\n")
else:
return text
class _BaseGoogleGenerativeAI(BaseModel):
    """Configuration fields shared by the Google Generative AI LLM wrappers."""

    model: str = Field(
        ...,
        description=(
            "The name of the model to use.\n"
            "Supported examples:\n"
            "- gemini-pro\n"
            "- models/text-bison-001"
        ),
    )
    """Name of the model to call (required)."""

    google_api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("GOOGLE_API_KEY", default=None),
        alias="api_key",
    )
    """API key; also settable as ``api_key``. Defaults to the value of the
    ``GOOGLE_API_KEY`` environment variable, or ``None`` when that is unset."""

    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to
    use when making API calls. If not provided, credentials will be ascertained
    from the GOOGLE_API_KEY envvar."""

    temperature: float = 0.7
    """Sampling temperature for inference. Must be in the closed interval
    [0.0, 1.0]."""

    top_p: Optional[float] = None
    """Nucleus sampling: decode from the smallest set of tokens whose summed
    probability is at least ``top_p``. Must be in the closed interval
    [0.0, 1.0]."""

    top_k: Optional[int] = None
    """Top-k sampling: decode from the ``top_k`` most probable tokens. Must be
    positive."""

    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
    """Upper bound on the tokens in a candidate; also settable as
    ``max_tokens``. Must be greater than zero. Earlier documentation states the
    service defaults to 64 when unset (TODO confirm)."""

    n: int = 1
    """How many completions to request per prompt. The API may return fewer
    than ``n`` if duplicates are generated."""

    max_retries: int = 6
    """Maximum number of retries to make when generating."""

    timeout: Optional[float] = None
    """Maximum number of seconds to wait for a response."""

    client_options: Optional[Dict] = Field(
        default=None,
        description=(
            "A dictionary of client options to pass to the Google API client, "
            "such as `api_endpoint`."
        ),
    )

    transport: Optional[str] = Field(
        default=None,
        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
    )

    additional_headers: Optional[Dict[str, str]] = Field(
        default=None,
        description=(
            "A key-value dictionary representing additional headers for the model call"
        ),
    )

    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
    """Default safety settings applied to every generation.

    For example:

        from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory

        safety_settings = {
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        }
    """  # noqa: E501

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret field name to the environment variable that holds it."""
        return dict(google_api_key="GOOGLE_API_KEY")

    @property
    def _model_family(self) -> str:
        """Return the ``GoogleModelFamily`` resolved from ``self.model``."""
        family = GoogleModelFamily(self.model)
        return family

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model configuration."""
        return dict(
            model=self.model,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_output_tokens=self.max_output_tokens,
            candidate_count=self.n,
        )
[docs]
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
    """Google GenerativeAI models.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAI

            llm = GoogleGenerativeAI(model="gemini-pro")
    """

    client: Any = None  #: :meta private:

    model_config = ConfigDict(
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validates params and passes them to google-generativeai package.

        Configures ``genai`` with either the explicit credentials or the API
        key, builds ``self.client`` for the model family, then range-checks
        the sampling parameters.

        Raises:
            ValueError: If the model name matches no known family, if safety
                settings are given for a non-Gemini model, or if a numeric
                parameter is outside its allowed range.
        """
        if self.credentials:
            genai.configure(
                credentials=self.credentials,
                transport=self.transport,
                client_options=self.client_options,
            )
        else:
            if isinstance(self.google_api_key, SecretStr):
                google_api_key: Optional[str] = self.google_api_key.get_secret_value()
            else:
                google_api_key = self.google_api_key
            genai.configure(
                api_key=google_api_key,
                transport=self.transport,
                client_options=self.client_options,
            )

        # Resolve the family once; an unrecognised model name raises here.
        is_gemini = GoogleModelFamily(self.model) == GoogleModelFamily.GEMINI
        if self.safety_settings and not is_gemini:
            raise ValueError("Safety settings are only supported for Gemini models")

        if is_gemini:
            self.client = genai.GenerativeModel(
                model_name=self.model, safety_settings=self.safety_settings
            )
        else:
            # Non-Gemini models are called through the module-level functions.
            self.client = genai

        if self.temperature is not None and not 0 <= self.temperature <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")
        if self.top_p is not None and not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")
        if self.top_k is not None and self.top_k <= 0:
            raise ValueError("top_k must be positive")
        if self.max_output_tokens is not None and self.max_output_tokens <= 0:
            raise ValueError("max_output_tokens must be greater than zero")
        if self.timeout is not None and self.timeout <= 0:
            raise ValueError("timeout must be greater than zero")
        return self

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        ls_params = super()._get_ls_params(stop=stop, **kwargs)
        ls_params["ls_provider"] = "google_genai"
        if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        return ls_params

    def _default_generation_config(
        self, stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Build the generation config from this instance's settings and `stop`."""
        return {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate completions for each prompt.

        Args:
            prompts: The prompts to complete, one API call per prompt.
            stop: Optional stop sequences placed in the generation config.
            run_manager: Optional callback manager passed to the retry layer.
            **kwargs: Only ``safety_settings`` is read (Gemini models).

        Returns:
            An ``LLMResult`` with one list of generations per prompt.
        """
        generation_config = self._default_generation_config(stop)
        # Pop once, before the loop: popping per iteration would hand the
        # per-call safety settings to the first prompt only.
        safety_settings = kwargs.pop("safety_settings", None)
        is_gemini = self._model_family == GoogleModelFamily.GEMINI

        generations: List[List[Generation]] = []
        for prompt in prompts:
            if is_gemini:
                res = _completion_with_retry(
                    self,
                    prompt=prompt,
                    stream=False,
                    is_gemini=True,
                    run_manager=run_manager,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )
                generations.append(
                    [
                        Generation(
                            text="".join(part.text for part in candidate.content.parts)
                        )
                        for candidate in res.candidates
                    ]
                )
            else:
                res = _completion_with_retry(
                    self,
                    model=self.model,
                    prompt=prompt,
                    stream=False,
                    is_gemini=False,
                    run_manager=run_manager,
                    **generation_config,
                )
                generations.append(
                    [
                        Generation(
                            text=_strip_erroneous_leading_spaces(candidate["output"])
                        )
                        for candidate in res.candidates
                    ]
                )
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the completion for a single prompt as chunks.

        This always takes the Gemini call path (``is_gemini=True``).

        Args:
            prompt: The prompt to complete.
            stop: Optional stop sequences placed in the generation config.
            run_manager: Optional callback manager; notified of each new token.
            **kwargs: ``generation_config`` entries override the instance
                defaults and ``safety_settings`` is forwarded; any remaining
                keywords are passed through to the retry helper.

        Yields:
            A ``GenerationChunk`` for each streamed response.
        """
        # Pop (rather than get) so the key is not passed a second time through
        # **kwargs below, which would raise "multiple values for keyword".
        overrides = kwargs.pop("generation_config", None) or {}
        generation_config = self._default_generation_config(stop) | overrides
        safety_settings = kwargs.pop("safety_settings", None)

        for stream_resp in _completion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=True,
            run_manager=run_manager,
            generation_config=generation_config,
            safety_settings=safety_settings,
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_resp.text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    stream_resp.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self._model_family == GoogleModelFamily.GEMINI:
            result = self.client.count_tokens(text)
            token_count = result.total_tokens
        else:
            result = self.client.count_text_tokens(model=self.model, prompt=text)
            token_count = result["token_count"]
        return token_count