class VertexAI(_VertexAICommon, BaseLLM):
    """Google Vertex AI large language models."""

    model_name: str = Field(default="gemini-2.0-flash-001", alias="model")
    "The name of the Vertex AI large language model."

    tuned_model_name: Optional[str] = None
    """The name of a tuned model. If tuned_model_name is passed,
    model_name will be used to determine the model family.
    """

    response_mime_type: Optional[str] = None
    """Optional. Output response MIME type of the generated candidate text. Only
    supported in Gemini 1.5 and later models. Supported MIME types:

        * "text/plain": (default) Text output.
        * "application/json": JSON response in the candidates.
        * "text/x.enum": Enum in plain text.

    The model also needs to be prompted to output the appropriate response
    type; otherwise the behavior is undefined. This is a preview feature.
    """

    response_schema: Optional[Dict[str, Any]] = None
    """Optional. Enforce a schema on the output.
    The format of the dictionary should follow OpenAPI schema.
    """

    model_config = ConfigDict(
        populate_by_name=True,
    )

    def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
        """Needed for mypy typing to recognize model_name as a valid arg."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(**kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "vertexai"]

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that the python package exists in environment."""
        if "medlm" in self.model_name and self.model_family == GoogleModelFamily.PALM:
            err = (
                "MedLM on PaLM is no longer supported! Please use Gemini or "
                "switch to langchain-google-vertexai==2.0.13"
            )
            raise ValueError(err)
        # All generation is delegated to a chat client configured with the
        # same parameters as this LLM.
        self.client = ChatVertexAI(
            model_name=self.model_name,
            tuned_model_name=self.tuned_model_name,
            project=self.project,
            location=self.location,
            credentials=self.credentials,
            api_transport=self.api_transport,
            api_endpoint=self.api_endpoint,
            default_metadata=self.default_metadata,
            temperature=self.temperature,
            max_output_tokens=self.max_output_tokens,
            top_p=self.top_p,
            top_k=self.top_k,
            safety_settings=self.safety_settings,
            n=self.n,
            seed=self.seed,
            response_schema=self.response_schema,
            response_mime_type=self.response_mime_type,
        )
        return self
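    # A minimal structured-output sketch (an illustration, not part of this
    # module): given ambient Google Cloud credentials, ``response_mime_type``
    # and ``response_schema`` constrain the client wired up above to emit
    # JSON matching the (hypothetical) schema.
    #
    #     llm = VertexAI(
    #         model_name="gemini-2.0-flash-001",
    #         response_mime_type="application/json",
    #         response_schema={
    #             "type": "object",
    #             "properties": {"answer": {"type": "string"}},
    #         },
    #     )
    #     llm.invoke("Reply in JSON: what is the capital of France?")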
tracing."""params=self._prepare_params(stop=stop,**kwargs)ls_params=super()._get_ls_params(stop=stop,**params)ls_params["ls_provider"]="google_vertexai"ifls_max_tokens:=params.get("max_output_tokens",self.max_output_tokens):ls_params["ls_max_tokens"]=ls_max_tokensifls_stop:=stoporself.stop:ls_params["ls_stop"]=ls_stopreturnls_paramsdef_generate(self,prompts:List[str],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,stream:Optional[bool]=None,**kwargs:Any,)->LLMResult:generations:List[List[Generation]]=[]forpromptinprompts:chat_result=self.client._generate([HumanMessage(content=prompt)],stop=stop,stream=stream,run_manager=run_manager,**kwargs,)generations.append([Generation(text=g.message.content,generation_info={**g.generation_info},)forginchat_result.generations])returnLLMResult(generations=generations)asyncdef_agenerate(self,prompts:List[str],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->LLMResult:generations:List[List[Generation]]=[]forpromptinprompts:chat_result=awaitself.client._agenerate([HumanMessage(content=prompt)],stop=stop,run_manager=run_manager,**kwargs,)generations.append([Generation(text=g.message.content,generation_info={**g.generation_info,},)forginchat_result.generations])returnLLMResult(generations=generations)@staticmethoddef_lc_usage_to_metadata(lc_usage:Dict[str,Any])->Dict[str,Any]:mapping={"input_tokens":"prompt_token_count","output_tokens":"candidates_token_count","total_tokens":"total_token_count",}return{mapping[k]:vfork,vinlc_usage.items()ifvandkinmapping}def_stream(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->Iterator[GenerationChunk]:forstream_chunkinself.client._stream([HumanMessage(content=prompt)],stop=stop,run_manager=run_manager,**kwargs,):ifstream_chunk.message.usage_metadata:lc_usage=stream_chunk.message.usage_metadatausage_metadata={**lc_usage,**self._lc_usage_to_metadata(lc_usage=lc_usage),}else:usage_metadata={}chunk=GenerationChunk(text=stream_chunk.message.content,generation_info={**stream_chunk.generation_info,**{"usage_metadata":usage_metadata},},)yieldchunkifrun_manager:run_manager.on_llm_new_token(chunk.text,chunk=chunk,verbose=self.verbose,)asyncdef_astream(self,prompt:str,stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->AsyncIterator[GenerationChunk]:asyncforstream_chunkinself.client._astream([HumanMessage(content=prompt)],stop=stop,run_manager=run_manager,**kwargs,):chunk=GenerationChunk(text=stream_chunk.message.content)yieldchunkifrun_manager:awaitrun_manager.on_llm_new_token(chunk.text,chunk=chunk,verbose=self.verbose)
    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        return self.client.get_num_tokens(text)
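
# A minimal end-to-end sketch, not part of the library: it assumes
# Application Default Credentials and a default Google Cloud project are
# already configured (e.g. via ``gcloud auth application-default login``).
if __name__ == "__main__":
    llm = VertexAI(model_name="gemini-2.0-flash-001")
    # Token counting delegates to the underlying ChatVertexAI client.
    print("tokens:", llm.get_num_tokens("hello world"))
    print(llm.invoke("What day comes after Friday?"))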