from __future__ import annotations

from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

from google.cloud.aiplatform import telemetry
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM, LangSmithParams
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from vertexai.generative_models import (  # type: ignore[import-untyped]
    Candidate,
    GenerativeModel,
    Image,
)
from vertexai.language_models import (  # type: ignore[import-untyped]
    CodeGenerationModel,
    TextGenerationModel,
)
from vertexai.language_models._language_models import (  # type: ignore[import-untyped]
    TextGenerationResponse,
)
from vertexai.preview.language_models import (  # type: ignore[import-untyped]
    CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
    TextGenerationModel as PreviewTextGenerationModel,
)

from langchain_google_vertexai._base import GoogleModelFamily, _VertexAICommon
from langchain_google_vertexai._utils import (
    create_retry_decorator,
    get_generation_info,
    is_gemini_model,
)


def _completion_with_retry(
    llm: VertexAI,
    prompt: List[Union[str, Image]],
    stream: bool = False,
    is_gemini: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry_inner(
        prompt: List[Union[str, Image]], is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return llm.client.generate_content(
                prompt,
                stream=stream,
                safety_settings=kwargs.pop("safety_settings", None),
                generation_config=kwargs,
            )
        else:
            if stream:
                return llm.client.predict_streaming(prompt[0], **kwargs)
            return llm.client.predict(prompt[0], **kwargs)

    with telemetry.tool_context_manager(llm._user_agent):
        return _completion_with_retry_inner(prompt, is_gemini, **kwargs)


async def _acompletion_with_retry(
    llm: VertexAI,
    prompt: str,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    async def _acompletion_with_retry_inner(
        prompt: str, is_gemini: bool = False, stream: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return await llm.client.generate_content_async(
                prompt,
                generation_config=kwargs,
                stream=stream,
                safety_settings=kwargs.pop("safety_settings", None),
            )
        if stream:
            raise ValueError("Async streaming is supported only for Gemini family!")
        return await llm.client.predict_async(prompt, **kwargs)

    with telemetry.tool_context_manager(llm._user_agent):
        return await _acompletion_with_retry_inner(
            prompt, is_gemini, stream=stream, **kwargs
        )
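
# The decorator returned by `create_retry_decorator` above is tenacity-based:
# it retries the wrapped call on transient Google API errors with exponential
# backoff and can surface retry attempts through the callback manager. A
# rough, illustrative sketch of the underlying pattern (not the actual
# helper, which lives in langchain_google_vertexai._utils):
#
#     from tenacity import retry, stop_after_attempt, wait_exponential
#
#     @retry(stop=stop_after_attempt(max_retries), wait=wait_exponential())
#     def _call_with_backoff(*args, **kwargs):
#         ...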
class VertexAI(_VertexAICommon, BaseLLM):
    """Google Vertex AI large language models."""

    model_name: str = Field(default="text-bison", alias="model")
    """The name of the Vertex AI large language model."""
    tuned_model_name: Optional[str] = None
    """The name of a tuned model. If tuned_model_name is passed,
    model_name will be used to determine the model family."""

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "vertexai"]

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        tuned_model_name = values.get("tuned_model_name")
        safety_settings = values["safety_settings"]
        values["model_family"] = GoogleModelFamily(values["model_name"])
        is_gemini = is_gemini_model(values["model_family"])
        cls._init_vertexai(values)

        if safety_settings and (not is_gemini or tuned_model_name):
            raise ValueError("Safety settings are only supported for Gemini models")

        if values["model_family"] == GoogleModelFamily.CODEY:
            model_cls = CodeGenerationModel
            preview_model_cls = PreviewCodeGenerationModel
        elif is_gemini:
            model_cls = GenerativeModel
            preview_model_cls = GenerativeModel
        else:
            model_cls = TextGenerationModel
            preview_model_cls = PreviewTextGenerationModel

        if tuned_model_name:
            generative_model_name = values["tuned_model_name"]
        else:
            generative_model_name = values["model_name"]

        if is_gemini:
            values["client"] = model_cls(
                model_name=generative_model_name, safety_settings=safety_settings
            )
            values["client_preview"] = preview_model_cls(
                model_name=generative_model_name, safety_settings=safety_settings
            )
        else:
            if tuned_model_name:
                values["client"] = model_cls.get_tuned_model(generative_model_name)
                values["client_preview"] = preview_model_cls.get_tuned_model(
                    generative_model_name
                )
            else:
                values["client"] = model_cls.from_pretrained(generative_model_name)
                values["client_preview"] = preview_model_cls.from_pretrained(
                    generative_model_name
                )

        if values["streaming"] and values["n"] > 1:
            raise ValueError("Only one candidate can be generated with streaming!")
        return values

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._prepare_params(stop=stop, **kwargs)
        ls_params = super()._get_ls_params(stop=stop, **params)
        ls_params["ls_provider"] = "google_vertexai"
        if ls_max_tokens := params.get("max_output_tokens", self.max_output_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _candidate_to_generation(
        self,
        response: Union[Candidate, TextGenerationResponse],
        *,
        stream: bool = False,
        usage_metadata: Optional[Dict] = None,
    ) -> GenerationChunk:
        """Converts a stream response to a generation chunk."""
        generation_info = get_generation_info(
            response,
            self._is_gemini_model,
            stream=stream,
            usage_metadata=usage_metadata,
        )
        try:
            text = response.text
        except (AttributeError, ValueError):
            # Candidates without text (e.g. blocked or empty content) raise.
            text = ""
        return GenerationChunk(
            text=text,
            generation_info=generation_info,
        )

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        should_stream = stream if stream is not None else self.streaming
        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
        generations: List[List[Generation]] = []
        for prompt in prompts:
            if should_stream:
                # Accumulate streamed chunks into a single generation.
                generation = GenerationChunk(text="")
                for chunk in self._stream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    generation += chunk
                generations.append([generation])
            else:
                res = _completion_with_retry(
                    self,
                    [prompt],
                    stream=should_stream,
                    is_gemini=self._is_gemini_model,
                    run_manager=run_manager,
                    **params,
                )
                if self._is_gemini_model:
                    usage_metadata = res.to_dict().get("usage_metadata")
                else:
                    usage_metadata = res.raw_prediction_response.metadata
                generations.append(
                    [
                        self._candidate_to_generation(r, usage_metadata=usage_metadata)
                        for r in res.candidates
                    ]
                )
        return LLMResult(generations=generations)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        params = self._prepare_params(stop=stop, **kwargs)
        generations: List[List[Generation]] = []
        for prompt in prompts:
            res = await _acompletion_with_retry(
                self,
                prompt,
                is_gemini=self._is_gemini_model,
                run_manager=run_manager,
                **params,
            )
            if self._is_gemini_model:
                usage_metadata = res.to_dict().get("usage_metadata")
            else:
                usage_metadata = res.raw_prediction_response.metadata
            generations.append(
                [
                    self._candidate_to_generation(r, usage_metadata=usage_metadata)
                    for r in res.candidates
                ]
            )
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        for stream_resp in _completion_with_retry(
            self,
            [prompt],
            stream=True,
            is_gemini=self._is_gemini_model,
            run_manager=run_manager,
            **params,
        ):
            usage_metadata = None
            if self._is_gemini_model:
                usage_metadata = stream_resp.to_dict().get("usage_metadata")
                stream_resp = stream_resp.candidates[0]
            chunk = self._candidate_to_generation(
                stream_resp, stream=True, usage_metadata=usage_metadata
            )
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        if not self._is_gemini_model:
            raise ValueError("Async streaming is supported only for Gemini family!")
        async for chunk in await _acompletion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=self._is_gemini_model,
            run_manager=run_manager,
            **params,
        ):
            usage_metadata = chunk.to_dict().get("usage_metadata")
            chunk = self._candidate_to_generation(
                chunk.candidates[0], stream=True, usage_metadata=usage_metadata
            )
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(
                    chunk.text, chunk=chunk, verbose=self.verbose
                )