import re
import string
from typing import Any, Dict, List, Optional

# TODO: remove ignore once the google package is published with types
from google.ai.generativelanguage_v1beta.types import (
    BatchEmbedContentsRequest,
    EmbedContentRequest,
    EmbedContentResponse,
)
from langchain_core.embeddings import Embeddings
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_google_genai._common import (
    GoogleGenerativeAIError,
    get_client_info,
)
from langchain_google_genai._genai_extension import build_generative_service

# Conservative ceiling on estimated tokens per batched embed request.
_MAX_TOKENS_PER_BATCH = 20000
# Default number of texts per batch (API maximum is 100 strings per request).
_DEFAULT_BATCH_SIZE = 100
class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
    """`Google Generative AI Embeddings`.

    To use, you must have either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           GoogleGenerativeAIEmbeddings constructor.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAIEmbeddings

            embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embeddings.embed_query("What's our Q1 revenue?")
    """

    client: Any = None  #: :meta private:
    model: str = Field(
        ...,
        description="The name of the embedding model to use. "
        "Example: models/embedding-001",
    )
    task_type: Optional[str] = Field(
        default=None,
        description="The task type. Valid options include: "
        "task_type_unspecified, retrieval_query, retrieval_document, "
        "semantic_similarity, classification, and clustering",
    )
    google_api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("GOOGLE_API_KEY", default=None),
        description=(
            "The Google API key to use. If not provided, "
            "the GOOGLE_API_KEY environment variable will be used."
        ),
    )
    credentials: Any = Field(
        default=None,
        exclude=True,
        description="The default custom credentials "
        "(google.auth.credentials.Credentials) to use when making API calls. If not "
        "provided, credentials will be ascertained from the GOOGLE_API_KEY envvar",
    )
    client_options: Optional[Dict] = Field(
        default=None,
        description=(
            "A dictionary of client options to pass to the Google API client, "
            "such as `api_endpoint`."
        ),
    )
    transport: Optional[str] = Field(
        default=None,
        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
    )
    request_options: Optional[Dict] = Field(
        default=None,
        description="A dictionary of request options to pass to the Google API client."
        "Example: `{'timeout': 10}`",
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validates params and passes them to google-generativeai package."""
        if isinstance(self.google_api_key, SecretStr):
            google_api_key: Optional[str] = self.google_api_key.get_secret_value()
        else:
            google_api_key = self.google_api_key
        client_info = get_client_info("GoogleGenerativeAIEmbeddings")
        self.client = build_generative_service(
            credentials=self.credentials,
            api_key=google_api_key,
            client_info=client_info,
            client_options=self.client_options,
        )
        return self

    @staticmethod
    def _split_by_punctuation(text: str) -> List[str]:
        """Splits a string by punctuation and whitespace characters."""
        split_by = string.punctuation + "\t\n "
        # NOTE: this character class is regex-safe only because
        # string.punctuation happens to contain "[\]" in escaping order and
        # ",-." forms a range covering exactly those characters — do not
        # reorder `split_by` without re.escape-ing it.
        pattern = f"([{split_by}])"
        # Using re.split to split the text based on the pattern; empty
        # segments between adjacent separators are dropped.
        return [segment for segment in re.split(pattern, text) if segment]

    @staticmethod
    def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
        """Splits texts in batches based on current maximum batch size
        and maximum tokens per request.
        """
        text_index = 0
        texts_len = len(texts)
        batch_token_len = 0
        batches: List[List[str]] = []
        current_batch: List[str] = []
        if texts_len == 0:
            return []
        while text_index < texts_len:
            current_text = texts[text_index]
            # Number of tokens per a text is conservatively estimated
            # as 2 times number of words, punctuation and whitespace characters.
            # Using `count_tokens` API will make batching too expensive.
            # Utilizing a tokenizer, would add a dependency that would not
            # necessarily be reused by the application using this class.
            current_text_token_cnt = (
                len(GoogleGenerativeAIEmbeddings._split_by_punctuation(current_text))
                * 2
            )
            end_of_batch = False
            if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
                # Current text is too big even for a single batch.
                # Such request will fail, but we still make a batch
                # so that the app can get the error from the API.
                if len(current_batch) > 0:
                    # Adding current batch if not empty.
                    batches.append(current_batch)
                current_batch = [current_text]
                text_index += 1
                end_of_batch = True
            elif (
                batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
                or len(current_batch) == batch_size
            ):
                # Current text would overflow the batch; close the batch
                # without consuming the text (it starts the next batch).
                end_of_batch = True
            else:
                if text_index == texts_len - 1:
                    # Last element - even though the batch may be not big,
                    # we still need to make it.
                    end_of_batch = True
                batch_token_len += current_text_token_cnt
                current_batch.append(current_text)
                text_index += 1
            if end_of_batch:
                batches.append(current_batch)
                current_batch = []
                batch_token_len = 0
        return batches

    def _prepare_request(
        self,
        text: str,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        output_dimensionality: Optional[int] = None,
    ) -> EmbedContentRequest:
        """Build a single EmbedContentRequest for `text`.

        Precedence for task type: instance-level `self.task_type`, then the
        explicit argument, then "RETRIEVAL_DOCUMENT".
        """
        task_type = self.task_type or task_type or "RETRIEVAL_DOCUMENT"
        # https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
        request = EmbedContentRequest(
            content={"parts": [{"text": text}]},
            model=self.model,
            task_type=task_type.upper(),
            title=title,
            output_dimensionality=output_dimensionality,
        )
        return request

    def embed_documents(
        self,
        texts: List[str],
        *,
        batch_size: int = _DEFAULT_BATCH_SIZE,
        task_type: Optional[str] = None,
        titles: Optional[List[str]] = None,
        output_dimensionality: Optional[int] = None,
    ) -> List[List[float]]:
        """Embed a list of strings.

        Google Generative AI currently sets a max batch size of 100 strings.

        Args:
            texts: List[str] The list of strings to embed.
            batch_size: [int] The batch size of embeddings to send to the model
            task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
            titles: An optional list of titles for texts provided.
                Only applicable when TaskType is RETRIEVAL_DOCUMENT.
            output_dimensionality: Optional reduced dimension for the output embedding.
                https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

        Returns:
            List of embeddings, one for each text.

        Raises:
            GoogleGenerativeAIError: if the underlying API call fails.
        """
        embeddings: List[List[float]] = []
        batch_start_index = 0
        for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
            if titles:
                # Keep titles aligned with their texts across batch boundaries.
                titles_batch = titles[batch_start_index : batch_start_index + len(batch)]
                batch_start_index += len(batch)
            else:
                titles_batch = [None] * len(batch)  # type: ignore[list-item]
            requests = [
                self._prepare_request(
                    text=text,
                    task_type=task_type,
                    title=title,
                    output_dimensionality=output_dimensionality,
                )
                for text, title in zip(batch, titles_batch)
            ]
            try:
                result = self.client.batch_embed_contents(
                    BatchEmbedContentsRequest(requests=requests, model=self.model)
                )
            except Exception as e:
                raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
            embeddings.extend([list(e.values) for e in result.embeddings])
        return embeddings

    def embed_query(
        self,
        text: str,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        output_dimensionality: Optional[int] = None,
    ) -> List[float]:
        """Embed a text, using the non-batch endpoint:
        https://ai.google.dev/api/rest/v1/models/embedContent#EmbedContentRequest

        Args:
            text: The text to embed.
            task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
            title: An optional title for the text.
                Only applicable when TaskType is RETRIEVAL_DOCUMENT.
            output_dimensionality: Optional reduced dimension for the output embedding.

        Returns:
            Embedding for the text.

        Raises:
            GoogleGenerativeAIError: if the underlying API call fails.
        """
        # BUG FIX: the explicit `task_type` argument was previously discarded
        # (`self.task_type or "RETRIEVAL_QUERY"` overwrote it). Honor it with
        # the same precedence `_prepare_request` uses: instance setting first,
        # then the argument, then the query default.
        task_type = self.task_type or task_type or "RETRIEVAL_QUERY"
        try:
            request: EmbedContentRequest = self._prepare_request(
                text=text,
                task_type=task_type,
                title=title,
                output_dimensionality=output_dimensionality,
            )
            result: EmbedContentResponse = self.client.embed_content(request)
        except Exception as e:
            raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
        return list(result.embedding.values)