import re
import string
from typing import Any, Dict, List, Optional
# TODO: remove ignore once the google package is published with types
from google.ai.generativelanguage_v1beta.types import (
BatchEmbedContentsRequest,
EmbedContentRequest,
)
from langchain_core.embeddings import Embeddings
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_google_genai._common import (
GoogleGenerativeAIError,
get_client_info,
)
from langchain_google_genai._genai_extension import build_generative_service
_MAX_TOKENS_PER_BATCH = 20000
_DEFAULT_BATCH_SIZE = 100
[docs]
class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
    """`Google Generative AI Embeddings`.

    To use, you must have either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the ``google_api_key`` kwarg to the
           constructor.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAIEmbeddings

            embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embeddings.embed_query("What's our Q1 revenue?")
    """

    # Service client; populated by ``validate_environment`` after validation.
    client: Any = None  #: :meta private:
    model: str = Field(
        ...,
        description="The name of the embedding model to use. "
        "Example: models/embedding-001",
    )
    task_type: Optional[str] = Field(
        default=None,
        description="The task type. Valid options include: "
        "task_type_unspecified, retrieval_query, retrieval_document, "
        "semantic_similarity, classification, and clustering",
    )
    google_api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("GOOGLE_API_KEY", default=None),
        description=(
            "The Google API key to use. If not provided, "
            "the GOOGLE_API_KEY environment variable will be used."
        ),
    )
    credentials: Any = Field(
        default=None,
        exclude=True,
        description="The default custom credentials "
        "(google.auth.credentials.Credentials) to use when making API calls. If not "
        "provided, credentials will be ascertained from the GOOGLE_API_KEY envvar",
    )
    client_options: Optional[Dict] = Field(
        default=None,
        description=(
            "A dictionary of client options to pass to the Google API client, "
            "such as `api_endpoint`."
        ),
    )
    # NOTE(review): ``transport`` is declared but not read anywhere in this
    # class; confirm whether it should be forwarded when building the client.
    transport: Optional[str] = Field(
        default=None,
        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
    )
    # NOTE(review): ``request_options`` is declared but not read anywhere in
    # this class; confirm whether it should be applied to the API calls.
    request_options: Optional[Dict] = Field(
        default=None,
        description="A dictionary of request options to pass to the Google API client."
        "Example: `{'timeout': 10}`",
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Build the service client from the validated fields.

        Unwraps the API key when it is a ``SecretStr`` and stores the client
        returned by ``build_generative_service`` on ``self.client``.

        Returns:
            The validated instance itself.
        """
        if isinstance(self.google_api_key, SecretStr):
            google_api_key: Optional[str] = self.google_api_key.get_secret_value()
        else:
            google_api_key = self.google_api_key
        client_info = get_client_info("GoogleGenerativeAIEmbeddings")
        self.client = build_generative_service(
            credentials=self.credentials,
            api_key=google_api_key,
            client_info=client_info,
            client_options=self.client_options,
        )
        return self

    @staticmethod
    def _split_by_punctuation(text: str) -> List[str]:
        """Split a string on punctuation and whitespace characters.

        The separators themselves are kept as individual segments; empty
        segments are dropped.

        Args:
            text: The string to split.

        Returns:
            Non-empty segments in their original order.
        """
        separators = string.punctuation + "\t\n "
        # Escape the separators before placing them in a character class.
        # Unescaped, the backslash in ``string.punctuation`` is parsed as an
        # escape for the following ``]`` and is never matched as a separator.
        pattern = f"([{re.escape(separators)}])"
        return [segment for segment in re.split(pattern, text) if segment]

    @staticmethod
    def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
        """Group texts into batches bounded by size and estimated tokens.

        A batch is closed when adding the next text would exceed
        ``_MAX_TOKENS_PER_BATCH`` estimated tokens or when it already holds
        ``batch_size`` texts. Input order is preserved.

        Args:
            texts: The texts to batch.
            batch_size: Maximum number of texts per batch; must be >= 1.

        Returns:
            The list of batches; empty when ``texts`` is empty.

        Raises:
            ValueError: If ``texts`` is non-empty and ``batch_size`` < 1.
        """
        if not texts:
            return []
        if batch_size < 1:
            # A zero batch size could never make progress (every batch would
            # be "full" while empty), so reject it up front.
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        batches: List[List[str]] = []
        current_batch: List[str] = []
        batch_token_len = 0
        for text in texts:
            # Number of tokens per a text is conservatively estimated
            # as 2 times number of words, punctuation and whitespace characters.
            # Using `count_tokens` API will make batching too expensive.
            # Utilizing a tokenizer, would add a dependency that would not
            # necessarily be reused by the application using this class.
            token_cnt = (
                len(GoogleGenerativeAIEmbeddings._split_by_punctuation(text)) * 2
            )
            if token_cnt > _MAX_TOKENS_PER_BATCH:
                # Current text is too big even for a single batch.
                # Such request will fail, but we still make a batch
                # so that the app can get the error from the API.
                if current_batch:
                    batches.append(current_batch)
                batches.append([text])
                current_batch = []
                batch_token_len = 0
                continue
            batch_is_full = (
                batch_token_len + token_cnt > _MAX_TOKENS_PER_BATCH
                or len(current_batch) == batch_size
            )
            if current_batch and batch_is_full:
                # Close the running batch; the text starts the next one.
                batches.append(current_batch)
                current_batch = []
                batch_token_len = 0
            current_batch.append(text)
            batch_token_len += token_cnt
        if current_batch:
            # Flush the trailing, possibly partial, batch.
            batches.append(current_batch)
        return batches

    def _prepare_request(
        self,
        text: str,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        output_dimensionality: Optional[int] = None,
    ) -> EmbedContentRequest:
        """Build a single ``EmbedContentRequest`` for ``text``.

        Args:
            text: The text to embed.
            task_type: Task type used only when the instance-level
                ``task_type`` is unset; falls back to ``RETRIEVAL_DOCUMENT``.
            title: Optional title attached to the request.
            output_dimensionality: Optional reduced output dimension.

        Returns:
            The request, with the task type upper-cased.
        """
        # The instance-level setting deliberately keeps precedence over the
        # per-call value, as callers may rely on this ordering.
        task_type = self.task_type or task_type or "RETRIEVAL_DOCUMENT"
        # https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
        return EmbedContentRequest(
            content={"parts": [{"text": text}]},
            model=self.model,
            task_type=task_type.upper(),
            title=title,
            output_dimensionality=output_dimensionality,
        )

    def embed_documents(
        self,
        texts: List[str],
        *,
        batch_size: int = _DEFAULT_BATCH_SIZE,
        task_type: Optional[str] = None,
        titles: Optional[List[str]] = None,
        output_dimensionality: Optional[int] = None,
    ) -> List[List[float]]:
        """Embed a list of strings. Google Generative AI currently
        sets a max batch size of 100 strings.

        Args:
            texts: List[str] The list of strings to embed.
            batch_size: [int] The batch size of embeddings to send to the model
            task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
            titles: An optional list of titles for texts provided.
                Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                Must contain at least one title per text; surplus titles
                are ignored.
            output_dimensionality: Optional reduced dimension for the output
                embedding.
                https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

        Returns:
            List of embeddings, one for each text.

        Raises:
            ValueError: If ``titles`` has fewer entries than ``texts`` or
                ``batch_size`` is less than 1.
            GoogleGenerativeAIError: If a batch request fails.
        """
        if titles and len(titles) < len(texts):
            # Pairing by zip would otherwise silently drop the texts that
            # have no matching title.
            raise ValueError(
                f"Got {len(titles)} titles for {len(texts)} texts; "
                "provide one title per text."
            )

        embeddings: List[List[float]] = []
        batch_start_index = 0
        for batch in self._prepare_batches(texts, batch_size):
            batch_titles: List[Optional[str]]
            if titles:
                batch_end_index = batch_start_index + len(batch)
                batch_titles = list(titles[batch_start_index:batch_end_index])
            else:
                batch_titles = [None] * len(batch)
            batch_start_index += len(batch)
            requests = [
                self._prepare_request(
                    text=text,
                    task_type=task_type,
                    title=title,
                    output_dimensionality=output_dimensionality,
                )
                for text, title in zip(batch, batch_titles)
            ]
            try:
                result = self.client.batch_embed_contents(
                    BatchEmbedContentsRequest(requests=requests, model=self.model)
                )
            except Exception as e:
                raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
            embeddings.extend(list(item.values) for item in result.embeddings)
        return embeddings

    def embed_query(
        self,
        text: str,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        output_dimensionality: Optional[int] = None,
    ) -> List[float]:
        """Embed a text.

        Args:
            text: The text to embed.
            task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType).
                Used when the instance-level ``task_type`` is unset; when
                neither is given, ``RETRIEVAL_QUERY`` is used.
            title: An optional title for the text.
                Only applicable when TaskType is RETRIEVAL_DOCUMENT.
            output_dimensionality: Optional reduced dimension for the output
                embedding.
                https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

        Returns:
            Embedding for the text.
        """
        # Honour the caller-supplied task type instead of discarding it.
        task_type = task_type or self.task_type or "RETRIEVAL_QUERY"
        return self.embed_documents(
            [text],
            task_type=task_type,
            titles=[title] if title else None,
            output_dimensionality=output_dimensionality,
        )[0]