Source code for langchain_community.embeddings.google_palm
from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import get_from_dict_or_env, pre_init
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator() -> Callable[[Any], Any]:
    """Return a tenacity retry decorator preconfigured to handle PaLM exceptions.

    The decorator makes at most ``max_retries`` attempts, waits with exponential
    backoff clamped to ``[min_seconds, max_seconds]`` seconds between attempts,
    logs a WARNING before each sleep, and re-raises the last exception once the
    attempts are exhausted.

    Errors that are retried:
      * ``ResourceExhausted`` / ``TooManyRequests`` (quota / rate limiting),
      * ``ServiceUnavailable``,
      * any other ``GoogleAPIError`` that is not a ``ClientError`` (4xx).

    Permanent client errors (e.g. ``InvalidArgument``, ``PermissionDenied``,
    ``Unauthenticated``) propagate immediately instead of being retried.
    """
    import google.api_core.exceptions as api_exceptions

    multiplier = 2
    min_seconds = 1
    max_seconds = 60
    max_retries = 10

    # Rate-limit and availability errors are worth retrying explicitly. They
    # are listed first because ResourceExhausted / TooManyRequests are
    # themselves ClientError subclasses and would otherwise be excluded below.
    transient_errors = retry_if_exception_type(
        (
            api_exceptions.ResourceExhausted,
            api_exceptions.TooManyRequests,
            api_exceptions.ServiceUnavailable,
        )
    )
    # GoogleAPIError is the root of every api_core error, so retrying on it
    # unconditionally would also retry permanent 4xx failures (bad request,
    # bad/missing API key) for minutes. Keep the broad retry for everything
    # else, but stop at client errors.
    other_non_client_errors = retry_if_exception_type(
        api_exceptions.GoogleAPIError
    ) & retry_if_not_exception_type(api_exceptions.ClientError)

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
        retry=transient_errors | other_non_client_errors,
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
def embed_with_retry(
    embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
) -> Any:
    """Call ``embeddings.client.generate_embeddings`` with tenacity retries.

    The retry policy comes from ``_create_retry_decorator``.

    Args:
        embeddings: Instance whose ``client`` performs the embedding request.
        *args: Positional arguments forwarded to ``generate_embeddings``.
        **kwargs: Keyword arguments forwarded to ``generate_embeddings``.

    Returns:
        The value returned by ``generate_embeddings`` on the first successful
        attempt.

    Raises:
        Exception: Errors the policy does not retry propagate immediately. Once
            the policy gives up on retried errors, the last one is re-raised.
    """
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
        return embeddings.client.generate_embeddings(*args, **kwargs)

    return _embed_with_retry(*args, **kwargs)
class GooglePalmEmbeddings(BaseModel, Embeddings):
    """Google's PaLM Embeddings APIs."""

    client: Any
    google_api_key: Optional[str]
    model_name: str = "models/embedding-gecko-001"
    """Model name to use."""
    show_progress_bar: bool = False
    """Whether to show a tqdm progress bar. Must have `tqdm` installed."""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the API key and that the ``google.generativeai`` package exists.

        Resolves ``google_api_key`` via ``get_from_dict_or_env``, which reads the
        ``google_api_key`` value or the ``GOOGLE_API_KEY`` environment variable.
        It then configures the ``google.generativeai`` module with that key and
        stores the module itself as ``client``.

        Args:
            values: Raw field values supplied to the model constructor.

        Returns:
            ``values`` with ``client`` set to the ``google.generativeai`` module.

        Raises:
            ImportError: If ``google.generativeai`` cannot be imported.
        """
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai
        except ImportError as e:
            # Chain the cause so the underlying import failure stays visible.
            raise ImportError(
                "Could not import google.generativeai python package."
            ) from e

        # Kept outside the try: an ImportError raised while configuring must
        # not be misreported as the package being missing.
        genai.configure(api_key=google_api_key)

        values["client"] = genai
        return values
[docs]defembed_documents(self,texts:List[str])->List[List[float]]:ifself.show_progress_bar:try:fromtqdmimporttqdmiter_=tqdm(texts,desc="GooglePalmEmbeddings")exceptImportError:logger.warning("Unable to show progress bar because tqdm could not be imported. ""Please install with `pip install tqdm`.")iter_=textselse:iter_=textsreturn[self.embed_query(text)fortextiniter_]