from __future__ import annotations

from concurrent.futures import Executor, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import pre_init
from pydantic import BaseModel, ConfigDict, Field

from langchain_community.utilities.vertexai import (
    create_retry_decorator,
    get_client_info,
    init_vertexai,
    raise_vertex_import_error,
)

if TYPE_CHECKING:
    from google.cloud.aiplatform.gapic import (
        PredictionServiceAsyncClient,
        PredictionServiceClient,
    )
    from google.cloud.aiplatform.models import Prediction
    from google.protobuf.struct_pb2 import Value
    from vertexai.language_models._language_models import (
        TextGenerationResponse,
        _LanguageModel,
    )
    from vertexai.preview.generative_models import Image

# This is for backwards compatibility
# We can remove after `langchain` stops importing it
_response_to_generation = None
stream_completion_with_retry = None

def is_codey_model(model_name: str) -> bool:
    """Return True if the model name is a Codey model."""
    return "code" in model_name

def is_gemini_model(model_name: str) -> bool:
    """Return True if the model name is a Gemini model."""
    return model_name is not None and "gemini" in model_name

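# Illustrative check (not part of the original module): both helpers are plain
# substring tests on the model name, e.g.
#     is_codey_model("code-bison")   -> True
#     is_gemini_model("gemini-pro")  -> True
#     is_gemini_model("text-bison")  -> False
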
def completion_with_retry(  # type: ignore[no-redef]
    llm: VertexAI,
    prompt: List[Union[str, "Image"]],
    stream: bool = False,
    is_gemini: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(
        prompt: List[Union[str, "Image"]], is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return llm.client.generate_content(
                prompt, stream=stream, generation_config=kwargs
            )
        else:
            if stream:
                return llm.client.predict_streaming(prompt[0], **kwargs)
            return llm.client.predict(prompt[0], **kwargs)

    return _completion_with_retry(prompt, is_gemini, **kwargs)

async def acompletion_with_retry(
    llm: VertexAI,
    prompt: str,
    is_gemini: bool = False,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _acompletion_with_retry(
        prompt: str, is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return await llm.client.generate_content_async(
                prompt, generation_config=kwargs
            )
        return await llm.client.predict_async(prompt, **kwargs)

    return await _acompletion_with_retry(prompt, is_gemini, **kwargs)

class _VertexAIBase(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    project: Optional[str] = None
    "The default GCP project to use when making Vertex API calls."
    location: str = "us-central1"
    "The default location to use when making API calls."
    request_parallelism: int = 5
    "The amount of parallelism allowed for requests issued to VertexAI models. "
    "Default is 5."
    max_retries: int = 6
    """The maximum number of retries to make when generating."""
    task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
    stop: Optional[List[str]] = None
    "Optional list of stop words to use when generating."
    model_name: Optional[str] = None
    "Underlying model name."

    @classmethod
    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
        if cls.task_executor is None:
            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
        return cls.task_executor


class _VertexAICommon(_VertexAIBase):  # type: ignore[override]
    client: "_LanguageModel" = None  #: :meta private:
    client_preview: "_LanguageModel" = None  #: :meta private:
    model_name: str
    "Underlying model name."
    temperature: float = 0.0
    "Sampling temperature; it controls the degree of randomness in token selection."
    max_output_tokens: int = 128
    "Token limit determines the maximum amount of text output from one prompt."
    top_p: float = 0.95
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: int = 40
    "How the model selects tokens for output: the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."
    credentials: Any = Field(default=None, exclude=True)
    "The default custom credentials (google.auth.credentials.Credentials) to use "
    "when making API calls. If not provided, credentials will be ascertained from "
    "the environment."
    n: int = 1
    """How many completions to generate for each prompt."""
    streaming: bool = False
    """Whether to stream the results or not."""

    @property
    def _llm_type(self) -> str:
        return "vertexai"

    @property
    def is_codey_model(self) -> bool:
        return is_codey_model(self.model_name)

    @property
    def _is_gemini_model(self) -> bool:
        return is_gemini_model(self.model_name)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Gets the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _default_params(self) -> Dict[str, Any]:
        params = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        if not self.is_codey_model:
            params.update(
                {
                    "top_k": self.top_k,
                    "top_p": self.top_p,
                }
            )
        return params

    @classmethod
    def _try_init_vertexai(cls, values: Dict) -> None:
        allowed_params = ["project", "location", "credentials"]
        params = {k: v for k, v in values.items() if k in allowed_params}
        init_vertexai(**params)
        return None

    def _prepare_params(
        self,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict:
        stop_sequences = stop or self.stop
        params_mapping = {"n": "candidate_count"}
        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
        if stream or self.streaming:
            params.pop("candidate_count")
        return params

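# Illustrative note (not part of the original module): for a non-Codey model with
# the defaults above, ``_VertexAICommon._default_params`` resolves to
#     {"temperature": 0.0, "max_output_tokens": 128, "candidate_count": 1,
#      "top_k": 40, "top_p": 0.95}
# and ``_prepare_params`` renames a caller-supplied ``n`` to ``candidate_count``,
# dropping it entirely when streaming, since streaming yields a single candidate.
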
@deprecated(
    since="0.0.12",
    removal="1.0",
    alternative_import="langchain_google_vertexai.VertexAI",
)
class VertexAI(_VertexAICommon, BaseLLM):  # type: ignore[override]
    """Google Vertex AI large language models."""

    model_name: str = "text-bison"
    "The name of the Vertex AI large language model."
    tuned_model_name: Optional[str] = None
    "The name of a tuned model. If provided, model_name is ignored."

    @classmethod
    def is_lc_serializable(self) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "vertexai"]

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        tuned_model_name = values.get("tuned_model_name")
        model_name = values["model_name"]
        is_gemini = is_gemini_model(values["model_name"])
        cls._try_init_vertexai(values)
        try:
            from vertexai.language_models import (
                CodeGenerationModel,
                TextGenerationModel,
            )
            from vertexai.preview.language_models import (
                CodeGenerationModel as PreviewCodeGenerationModel,
            )
            from vertexai.preview.language_models import (
                TextGenerationModel as PreviewTextGenerationModel,
            )

            if is_gemini:
                from vertexai.preview.generative_models import (
                    GenerativeModel,
                )

            if is_codey_model(model_name):
                model_cls = CodeGenerationModel
                preview_model_cls = PreviewCodeGenerationModel
            elif is_gemini:
                model_cls = GenerativeModel
                preview_model_cls = GenerativeModel
            else:
                model_cls = TextGenerationModel
                preview_model_cls = PreviewTextGenerationModel

            if tuned_model_name:
                values["client"] = model_cls.get_tuned_model(tuned_model_name)
                values["client_preview"] = preview_model_cls.get_tuned_model(
                    tuned_model_name
                )
            else:
                if is_gemini:
                    values["client"] = model_cls(model_name=model_name)
                    values["client_preview"] = preview_model_cls(model_name=model_name)
                else:
                    values["client"] = model_cls.from_pretrained(model_name)
                    values["client_preview"] = preview_model_cls.from_pretrained(
                        model_name
                    )

        except ImportError:
            raise_vertex_import_error()

        if values["streaming"] and values["n"] > 1:
            raise ValueError("Only one candidate can be generated with streaming!")
        return values

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        try:
            result = self.client_preview.count_tokens([text])
        except AttributeError:
            raise_vertex_import_error()

        return result.total_tokens

    def _response_to_generation(
        self, response: TextGenerationResponse
    ) -> GenerationChunk:
        """Converts a stream response to a generation chunk."""
        try:
            generation_info = {
                "is_blocked": response.is_blocked,
                "safety_attributes": response.safety_attributes,
            }
        except Exception:
            generation_info = None
        return GenerationChunk(text=response.text, generation_info=generation_info)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        should_stream = stream if stream is not None else self.streaming
        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
        generations: List[List[Generation]] = []
        for prompt in prompts:
            if should_stream:
                generation = GenerationChunk(text="")
                for chunk in self._stream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    generation += chunk
                generations.append([generation])
            else:
                res = completion_with_retry(  # type: ignore[misc]
                    self,
                    [prompt],
                    stream=should_stream,
                    is_gemini=self._is_gemini_model,
                    run_manager=run_manager,
                    **params,
                )
                generations.append(
                    [self._response_to_generation(r) for r in res.candidates]
                )
        return LLMResult(generations=generations)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        params = self._prepare_params(stop=stop, **kwargs)
        generations = []
        for prompt in prompts:
            res = await acompletion_with_retry(
                self,
                prompt,
                is_gemini=self._is_gemini_model,
                run_manager=run_manager,
                **params,
            )
            generations.append(
                [self._response_to_generation(r) for r in res.candidates]
            )
        return LLMResult(generations=generations)  # type: ignore[arg-type]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        for stream_resp in completion_with_retry(  # type: ignore[misc]
            self,
            [prompt],
            stream=True,
            is_gemini=self._is_gemini_model,
            run_manager=run_manager,
            **params,
        ):
            chunk = self._response_to_generation(stream_resp)
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )
            yield chunk

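# Usage sketch (illustrative, not part of the original module). Assumes Vertex AI
# credentials and a GCP project are already configured in the environment:
#
#     from langchain_community.llms import VertexAI
#
#     llm = VertexAI(model_name="text-bison", temperature=0.2, max_output_tokens=256)
#     print(llm.invoke("Tell me a joke about tokenizers."))
#
#     # Streaming emits chunks incrementally; only one candidate is allowed (n=1).
#     for chunk in llm.stream("Write a haiku about retries."):
#         print(chunk, end="", flush=True)
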
@deprecated(
    since="0.0.12",
    removal="1.0",
    alternative_import="langchain_google_vertexai.VertexAIModelGarden",
)
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
    """Vertex AI Model Garden large language models."""

    client: "PredictionServiceClient" = (
        None  #: :meta private: # type: ignore[assignment]
    )
    async_client: "PredictionServiceAsyncClient" = (
        None  #: :meta private: # type: ignore[assignment]
    )
    endpoint_id: str
    "The name of the endpoint where the model has been deployed."
    allowed_model_args: Optional[List[str]] = None
    "Allowed optional args to be passed to the model."
    prompt_arg: str = "prompt"
    "The name of the request field that carries the prompt."
    result_arg: Optional[str] = "generated_text"
    "Set result_arg to None if the output of the model is expected to be a string. "
    "Otherwise, if it's a dict, provide the argument that contains the result."

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        try:
            from google.api_core.client_options import ClientOptions
            from google.cloud.aiplatform.gapic import (
                PredictionServiceAsyncClient,
                PredictionServiceClient,
            )
        except ImportError:
            raise_vertex_import_error()

        if not values["project"]:
            raise ValueError(
                "A GCP project should be provided to run inference on Model Garden!"
            )

        client_options = ClientOptions(
            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
        )
        client_info = get_client_info(module="vertex-ai-model-garden")
        values["client"] = PredictionServiceClient(
            client_options=client_options, client_info=client_info
        )
        values["async_client"] = PredictionServiceAsyncClient(
            client_options=client_options, client_info=client_info
        )
        return values

    @property
    def endpoint_path(self) -> str:
        return self.client.endpoint_path(
            project=self.project,  # type: ignore[arg-type]
            location=self.location,
            endpoint=self.endpoint_id,
        )

    @property
    def _llm_type(self) -> str:
        return "vertexai_model_garden"

    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
        try:
            from google.protobuf import json_format
            from google.protobuf.struct_pb2 import Value
        except ImportError:
            raise ImportError(
                "protobuf package not found, please install it with"
                " `pip install protobuf`"
            )
        instances = []
        for prompt in prompts:
            if self.allowed_model_args:
                instance = {
                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
                }
            else:
                instance = {}
            instance[self.prompt_arg] = prompt
            instances.append(instance)

        predict_instances = [
            json_format.ParseDict(instance_dict, Value())
            for instance_dict in instances
        ]
        return predict_instances

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)

    def _parse_response(self, predictions: "Prediction") -> LLMResult:
        generations: List[List[Generation]] = []
        for result in predictions.predictions:
            generations.append(
                [
                    Generation(text=self._parse_prediction(prediction))
                    for prediction in result
                ]
            )
        return LLMResult(generations=generations)

    def _parse_prediction(self, prediction: Any) -> str:
        if isinstance(prediction, str):
            return prediction

        if self.result_arg:
            try:
                return prediction[self.result_arg]
            except KeyError:
                if isinstance(prediction, str):
                    error_desc = (
                        "Provided non-None `result_arg` (result_arg="
                        f"{self.result_arg}). But got prediction of type "
                        f"{type(prediction)} instead of dict. Most probably, you "
                        "need to set `result_arg=None` during VertexAIModelGarden "
                        "initialization."
                    )
                    raise ValueError(error_desc)
                else:
                    raise ValueError(f"{self.result_arg} key not found in prediction!")

        return prediction

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)

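# Usage sketch (illustrative, not part of the original module). The project and
# endpoint values below are hypothetical placeholders for a model deployed through
# Vertex AI Model Garden:
#
#     from langchain_community.llms import VertexAIModelGarden
#
#     llm = VertexAIModelGarden(
#         project="my-gcp-project",     # hypothetical project ID
#         endpoint_id="1234567890",     # hypothetical endpoint ID
#         result_arg="generated_text",  # use None if the endpoint returns a bare string
#     )
#     print(llm.invoke("What is the meaning of life?"))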