def _create_retry_decorator(
    llm: BaseLLM,
    *,
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Creates a retry decorator for Vertex / Palm LLMs."""
    errors = [
        google.api_core.exceptions.ResourceExhausted,
        google.api_core.exceptions.ServiceUnavailable,
        google.api_core.exceptions.Aborted,
        google.api_core.exceptions.DeadlineExceeded,
        google.api_core.exceptions.GoogleAPIError,
    ]
    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def _completion_with_retry(
    llm: GoogleGenerativeAI,
    prompt: LanguageModelInput,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
    ) -> Any:
        generation_config = kwargs.get("generation_config", {})
        error_msg = (
            "Your location is not supported by google-generativeai at the moment. "
            "Try to use VertexAI LLM from langchain_google_vertexai"
        )
        try:
            if is_gemini:
                return llm.client.generate_content(
                    contents=prompt,
                    stream=stream,
                    generation_config=generation_config,
                    safety_settings=kwargs.pop("safety_settings", None),
                    request_options={"timeout": llm.timeout} if llm.timeout else None,
                )
            return llm.client.generate_text(prompt=prompt, **kwargs)
        except google.api_core.exceptions.FailedPrecondition as exc:
            if "location is not supported" in exc.message:
                raise ValueError(error_msg)

    return _completion_with_retry(
        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
    )


def _strip_erroneous_leading_spaces(text: str) -> str:
    """Strip erroneous leading spaces from text.

    The PaLM API will sometimes erroneously return a single leading space in all
    lines > 1. This function strips that space.
    """
    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
        return text


class _BaseGoogleGenerativeAI(BaseModel):
    """Base class for Google Generative AI LLMs"""

    model: str = Field(
        ...,
        description="""The name of the model to use.
Supported examples:
    - gemini-pro
    - models/text-bison-001""",
    )
    """Model name to use."""

    google_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env("GOOGLE_API_KEY", default=None),
    )
    credentials: Any = None
    "The default custom credentials (google.auth.credentials.Credentials) to use "
    "when making API calls. If not provided, credentials will be ascertained from "
    "the GOOGLE_API_KEY envvar"
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = None
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_retries: int = 6
    """The maximum number of retries to make when generating."""

    timeout: Optional[float] = None
    """The maximum number of seconds to wait for a response."""

    client_options: Optional[Dict] = Field(
        default=None,
        description=(
            "A dictionary of client options to pass to the Google API client, "
            "such as `api_endpoint`."
        ),
    )
    transport: Optional[str] = Field(
        default=None,
        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
    )
    additional_headers: Optional[Dict[str, str]] = Field(
        default=None,
        description=(
            "A key-value dictionary representing additional headers for the model call"
        ),
    )
    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
    """The default safety settings to use for all generations.

        For example:

            from google.generativeai.types.safety_types import (
                HarmBlockThreshold,
                HarmCategory,
            )

            safety_settings = {
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            }
    """  # noqa: E501

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}

    @property
    def _model_family(self) -> str:
        return GoogleModelFamily(self.model)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
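# Illustrative configuration sketch for the fields above (kept in a comment so the
# module stays import-safe). The model name, parameter values, and safety thresholds
# are examples only, and a valid GOOGLE_API_KEY is assumed to be set in the
# environment:
#
#     from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
#
#     llm = GoogleGenerativeAI(
#         model="gemini-pro",
#         temperature=0.3,
#         top_p=0.95,
#         max_output_tokens=256,
#         safety_settings={
#             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
#         },
#     )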
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
    """Google GenerativeAI models.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAI

            llm = GoogleGenerativeAI(model="gemini-pro")
    """

    client: Any = None  #: :meta private:

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates params and passes them to google-generativeai package."""
        if values.get("credentials"):
            genai.configure(
                credentials=values.get("credentials"),
                transport=values.get("transport"),
                client_options=values.get("client_options"),
            )
        else:
            google_api_key = values.get("google_api_key")
            if isinstance(google_api_key, SecretStr):
                google_api_key = google_api_key.get_secret_value()
            genai.configure(
                api_key=google_api_key,
                transport=values.get("transport"),
                client_options=values.get("client_options"),
            )

        model_name = values["model"]
        safety_settings = values["safety_settings"]

        if safety_settings and (
            not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
        ):
            raise ValueError("Safety settings are only supported for Gemini models")

        if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
            values["client"] = genai.GenerativeModel(
                model_name=model_name, safety_settings=safety_settings
            )
        else:
            values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
            raise ValueError("max_output_tokens must be greater than zero")

        if values["timeout"] is not None and values["timeout"] <= 0:
            raise ValueError("timeout must be greater than zero")

        return values

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        ls_params = super()._get_ls_params(stop=stop, **kwargs)
        ls_params["ls_provider"] = "google_genai"
        if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        return ls_params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations: List[List[Generation]] = []
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        for prompt in prompts:
            if self._model_family == GoogleModelFamily.GEMINI:
                res = _completion_with_retry(
                    self,
                    prompt=prompt,
                    stream=False,
                    is_gemini=True,
                    run_manager=run_manager,
                    generation_config=generation_config,
                    safety_settings=kwargs.pop("safety_settings", None),
                )
                candidates = [
                    "".join([p.text for p in c.content.parts]) for c in res.candidates
                ]
                generations.append([Generation(text=c) for c in candidates])
            else:
                res = _completion_with_retry(
                    self,
                    model=self.model,
                    prompt=prompt,
                    stream=False,
                    is_gemini=False,
                    run_manager=run_manager,
                    **generation_config,
                )
                prompt_generations = []
                for candidate in res.candidates:
                    raw_text = candidate["output"]
                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
                    prompt_generations.append(Generation(text=stripped_text))
                generations.append(prompt_generations)

        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        generation_config = generation_config | kwargs.get("generation_config", {})
        for stream_resp in _completion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=True,
            run_manager=run_manager,
            generation_config=generation_config,
            safety_settings=kwargs.pop("safety_settings", None),
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_resp.text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    stream_resp.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"
    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self._model_family == GoogleModelFamily.GEMINI:
            result = self.client.count_tokens(text)
            token_count = result.total_tokens
        else:
            result = self.client.count_text_tokens(model=self.model, prompt=text)
            token_count = result["token_count"]

        return token_count
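
# Token-counting sketch (illustrative only; counting delegates to the underlying
# google-generativeai client, so a reachable model and a valid GOOGLE_API_KEY are
# assumed):
#
#     llm = GoogleGenerativeAI(model="gemini-pro")
#     n_tokens = llm.get_num_tokens("How many tokens is this sentence?")
#     print(n_tokens)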