Source code for langchain_google_vertexai.chat_models
"""Wrapper around Google VertexAI chat-based models."""from__future__importannotations# noqaimportastfromfunctoolsimportcached_propertyimportjsonimportloggingimportrefromdataclassesimportdataclass,fieldfromoperatorimportitemgetterimportuuidfromtypingimport(Any,AsyncIterator,Callable,Dict,Iterator,List,Optional,Sequence,Type,Union,cast,Literal,Tuple,TypedDict,overload,)importproto# type: ignore[import-untyped]fromgoogle.cloud.aiplatformimporttelemetryfromlangchain_core.callbacksimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_modelsimportLanguageModelInputfromlangchain_core.language_models.chat_modelsimport(BaseChatModel,LangSmithParams,generate_from_stream,agenerate_from_stream,)fromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,FunctionMessage,HumanMessage,SystemMessage,ToolCall,ToolMessage,)fromlangchain_core.messages.aiimportUsageMetadatafromlangchain_core.messages.toolimport(tool_call_chunk,tool_callascreate_tool_call,invalid_tool_call,)fromlangchain_core.output_parsers.baseimportOutputParserLikefromlangchain_core.output_parsersimportJsonOutputParser,PydanticOutputParserfromlangchain_core.output_parsers.openai_toolsimport(JsonOutputKeyToolsParser,PydanticToolsParser,)fromlangchain_core.output_parsers.openai_toolsimportparse_tool_callsfromlangchain_core.outputsimportChatGeneration,ChatGenerationChunk,ChatResultfrompydanticimportBaseModel,Field,model_validatorfromlangchain_core.runnablesimportRunnable,RunnablePassthroughfromlangchain_core.utils.function_callingimport(convert_to_json_schema,convert_to_openai_tool,)fromlangchain_core.utils.pydanticimportis_basemodel_subclassfromvertexai.generative_modelsimport(# type: ignoreToolasVertexTool,)fromvertexai.generative_models._generative_modelsimport(# type: ignoreToolConfig,SafetySettingsType,GenerationConfigType,GenerationResponse,_convert_schema_dict_to_gapic,)fromvertexai.language_modelsimport(# type: ignoreChatMessage,ChatModel,ChatSession,CodeChatModel,CodeChatSession,InputOutputTextPair,)fromvertexai.preview.language_modelsimport(# type: 
ignoreChatModelasPreviewChatModel,)fromvertexai.preview.language_modelsimport(CodeChatModelasPreviewCodeChatModel,)fromgoogle.cloud.aiplatform_v1.typesimport(Contentasv1Content,FunctionCallingConfigasv1FunctionCallingConfig,GenerateContentRequestasv1GenerateContentRequest,GenerationConfigasv1GenerationConfig,Partasv1Part,SafetySettingasv1SafetySetting,Toolasv1Tool,ToolConfigasv1ToolConfig,)fromgoogle.cloud.aiplatform_v1beta1.typesimport(Blob,Candidate,Part,HarmCategory,Content,FileData,FunctionCall,FunctionResponse,GenerateContentRequest,GenerationConfig,SafetySetting,ToolasGapicTool,ToolConfigasGapicToolConfig,VideoMetadata,)fromlangchain_google_vertexai._baseimport_VertexAICommon,GoogleModelFamilyfromlangchain_google_vertexai._image_utilsimportImageBytesLoaderfromlangchain_google_vertexai._utilsimport(create_retry_decorator,get_generation_info,_format_model_name,is_gemini_model,replace_defs_in_schema,)fromlangchain_google_vertexai.functions_utilsimport(_format_tool_config,_ToolConfigDict,_tool_choice_to_tool_config,_ToolChoiceType,_ToolsType,_format_to_gapic_tool,_ToolType,)frompydanticimportConfigDictfrompydantic.v1importBaseModelasBaseModelV1fromtyping_extensionsimportSelf,is_typeddictlogger=logging.getLogger(__name__)_allowed_params=["temperature","top_k","top_p","response_mime_type","response_schema","max_output_tokens","presence_penalty","frequency_penalty","candidate_count","seed","response_logprobs","logprobs","labels",]_allowed_params_prediction_service=["request","timeout","metadata","labels"]@dataclassclass_ChatHistory:"""Represents a context and a history of messages."""history:List[ChatMessage]=field(default_factory=list)context:Optional[str]=Noneclass_GeminiGenerateContentKwargs(TypedDict):generation_config:Optional[GenerationConfigType]safety_settings:Optional[SafetySettingsType]tools:Optional[List[VertexTool]]tool_config:Optional[ToolConfig]def_parse_chat_history(history:List[BaseMessage])->_ChatHistory:"""Parse a sequence of messages into history. Args: history: The list of messages to re-create the history of the chat. Returns: A parsed chat history. Raises: ValueError: If a sequence of message has a SystemMessage not at the first place. 
"""vertex_messages,context=[],Nonefori,messageinenumerate(history):content=cast(str,message.content)ifi==0andisinstance(message,SystemMessage):context=contentelifisinstance(message,AIMessage):vertex_message=ChatMessage(content=message.content,author="bot")vertex_messages.append(vertex_message)elifisinstance(message,HumanMessage):vertex_message=ChatMessage(content=message.content,author="user")vertex_messages.append(vertex_message)else:raiseValueError(f"Unexpected message with type {type(message)} at the position {i}.")chat_history=_ChatHistory(context=context,history=vertex_messages)returnchat_historydef_parse_chat_history_gemini(history:List[BaseMessage],imageBytesLoader:ImageBytesLoader,convert_system_message_to_human:Optional[bool]=False,perform_literal_eval_on_string_raw_content:Optional[bool]=False,)->tuple[Content|None,list[Content]]:def_convert_to_prompt(part:Union[str,Dict])->Optional[Part]:ifisinstance(part,str):returnPart(text=part)ifnotisinstance(part,Dict):raiseValueError(f"Message's content is expected to be a dict, got {type(part)}!")ifpart["type"]=="text":returnPart(text=part["text"])ifpart["type"]=="tool_use":ifpart.get("text"):returnPart(text=part["text"])else:returnNoneifpart["type"]=="image_url":path=part["image_url"]["url"]returnimageBytesLoader.load_gapic_part(path)# Handle media type like LangChain.js# https://github.com/langchain-ai/langchainjs/blob/e536593e2585f1dd7b0afc187de4d07cb40689ba/libs/langchain-google-common/src/utils/gemini.ts#L93-L106ifpart["type"]=="media":if"mime_type"notinpart:raiseValueError(f"Missing mime_type in media part: {part}")mime_type=part["mime_type"]proto_part=Part()if"data"inpart:proto_part.inline_data=Blob(data=part["data"],mime_type=mime_type)elif"file_uri"inpart:proto_part.file_data=FileData(file_uri=part["file_uri"],mime_type=mime_type)else:raiseValueError(f"Media part must have either data or file_uri: {part}")if"video_metadata"inpart:metadata=VideoMetadata(part["video_metadata"])proto_part.video_metadata=metadatareturnproto_partraiseValueError("Only text, image_url, and media types are supported!")def_convert_to_parts(message:BaseMessage)->List[Part]:raw_content=message.content# If a user sends a multimodal request with agents, then the full input# will be sent as a string due to the ChatPromptTemplate formatting.# Because of this, we need to first try to convert the string to its# native type (such as list or dict) so that results can be properly# appended to the prompt, otherwise they will all be parsed as Text# rather than `inline_data`.ifperform_literal_eval_on_string_raw_contentandisinstance(raw_content,str):try:raw_content=ast.literal_eval(raw_content)exceptSyntaxError:passexceptValueError:pass# A linting error is thrown here because it does not think this line is# reachable due to typing, but mypy is wrong so we ignore the lint# error.ifisinstance(raw_content,int):# type: ignoreraw_content=str(raw_content)# type: ignoreifisinstance(raw_content,str):raw_content=[raw_content]result=[]forraw_partinraw_content:part=_convert_to_prompt(raw_part)ifpart:result.append(part)returnresultvertex_messages:List[Content]=[]system_parts:List[Part]|None=Nonesystem_instruction=None# the last AI Message before a sequence of tool callsprev_ai_message:Optional[AIMessage]=Nonefori,messageinenumerate(history):ifisinstance(message,SystemMessage):prev_ai_message=Nonesystem_parts=_convert_to_parts(message)ifconvert_system_message_to_human:logger.warning("gemini models released from April 2024 support""SystemMessages natively. 
For best performances,""when working with these models,""set convert_system_message_to_human to False")continueifsystem_instructionisnotNone:system_instruction.parts.extend(system_parts)else:system_instruction=Content(role="system",parts=system_parts)system_parts=Noneelifisinstance(message,HumanMessage):prev_ai_message=Nonerole="user"parts=_convert_to_parts(message)ifsystem_partsisnotNone:parts=system_parts+partssystem_parts=Noneifvertex_messagesandvertex_messages[-1].role=="user":prev_parts=list(vertex_messages[-1].parts)vertex_messages[-1]=Content(role=role,parts=prev_parts+parts)else:vertex_messages.append(Content(role=role,parts=parts))elifisinstance(message,AIMessage):prev_ai_message=messagerole="model"parts=[]ifmessage.content:parts=_convert_to_parts(message)fortcinmessage.tool_calls:function_call=FunctionCall({"name":tc["name"],"args":tc["args"]})parts.append(Part(function_call=function_call))iflen(vertex_messages):prev_content=vertex_messages[-1]prev_content_is_model=prev_contentandprev_content.role=="model"ifprev_content_is_model:prev_parts=list(prev_content.parts)prev_parts.extend(parts)vertex_messages[-1]=Content(role=role,parts=prev_parts)continuevertex_messages.append(Content(role=role,parts=parts))elifisinstance(message,FunctionMessage):prev_ai_message=Nonerole="function"part=Part(function_response=FunctionResponse(name=message.name,response={"content":message.content}))parts=[part]iflen(vertex_messages):prev_content=vertex_messages[-1]prev_content_is_function=(prev_contentandprev_content.role=="function")ifprev_content_is_function:prev_parts=list(prev_content.parts)prev_parts.extend(parts)# replacing last messagevertex_messages[-1]=Content(role=role,parts=prev_parts)continuevertex_messages.append(Content(role=role,parts=parts))elifisinstance(message,ToolMessage):role="function"# message.name can be null for ToolMessagename=message.nameifnameisNone:ifprev_ai_message:tool_call_id=message.tool_call_idtool_call:ToolCall|None=next((tfortinprev_ai_message.tool_callsift["id"]==tool_call_id),None,)iftool_callisNone:raiseValueError(("Message name is empty and can't find"+f"corresponding tool call for id: '${tool_call_id}'"))name=tool_call["name"]def_parse_content(raw_content:str|Dict[Any,Any])->Dict[Any,Any]:ifisinstance(raw_content,dict):returnraw_contentifisinstance(raw_content,str):try:content=json.loads(raw_content)# json.loads("2") returns 2 since it's a valid jsonifisinstance(content,dict):returncontentexceptjson.JSONDecodeError:passreturn{"content":raw_content}ifisinstance(message.content,list):parsed_content=[_parse_content(c)forcinmessage.content]iflen(parsed_content)>1:merged_content:Dict[Any,Any]={}forcontent_pieceinparsed_content:forkey,valueincontent_piece.items():ifkeynotinmerged_content:merged_content[key]=[]merged_content[key].append(value)logger.warning("Expected content to be a str, got a list with > 1 element.""Merging values together")content={k:"".join(v)fork,vinmerged_content.items()}else:content=parsed_content[0]else:content=_parse_content(message.content)part=Part(function_response=FunctionResponse(name=name,response=content,))parts=[part]prev_content=vertex_messages[-1]prev_content_is_function=prev_contentandprev_content.role=="function"ifprev_content_is_function:prev_parts=list(prev_content.parts)prev_parts.extend(parts)# replacing last messagevertex_messages[-1]=Content(role=role,parts=prev_parts)continuevertex_messages.append(Content(role=role,parts=parts))else:raiseValueError(f"Unexpected message with type {type(message)} at the position 
{i}.")returnsystem_instruction,vertex_messagesdef_parse_examples(examples:List[BaseMessage])->List[InputOutputTextPair]:iflen(examples)%2!=0:raiseValueError(f"Expect examples to have an even amount of messages, got {len(examples)}.")example_pairs=[]input_text=Nonefori,exampleinenumerate(examples):ifi%2==0:ifnotisinstance(example,HumanMessage):raiseValueError(f"Expected the first message in a part to be from human, got "f"{type(example)} for the {i}th message.")input_text=example.contentifi%2==1:ifnotisinstance(example,AIMessage):raiseValueError(f"Expected the second message in a part to be from AI, got "f"{type(example)} for the {i}th message.")pair=InputOutputTextPair(input_text=input_text,output_text=example.content)example_pairs.append(pair)returnexample_pairsdef_get_question(messages:List[BaseMessage])->HumanMessage:"""Get the human message at the end of a list of input messages to a chat model."""ifnotmessages:raiseValueError("You should provide at least one message to start the chat!")question=messages[-1]ifnotisinstance(question,HumanMessage):raiseValueError(f"Last message in the list should be from human, got {question.type}.")returnquestion@overloaddef_parse_response_candidate(response_candidate:"Candidate",streaming:Literal[False]=False)->AIMessage:...@overloaddef_parse_response_candidate(response_candidate:"Candidate",streaming:Literal[True])->AIMessageChunk:...def_parse_response_candidate(response_candidate:"Candidate",streaming:bool=False)->AIMessage:content:Union[None,str,List[str]]=Noneadditional_kwargs={}tool_calls=[]invalid_tool_calls=[]tool_call_chunks=[]forpartinresponse_candidate.content.parts:try:text:Optional[str]=part.textexceptAttributeError:text=Noneiftext:ifnotcontent:content=textelifisinstance(content,str):content=[content,text]elifisinstance(content,list):content.append(text)else:raiseException("Unexpected content type")ifpart.function_call:# For backward compatibility we store a function call in additional_kwargs,# but in general the full set of function calls is stored in tool_calls.function_call={"name":part.function_call.name}# dump to match other function calling llm for 
nowfunction_call_args_dict=proto.Message.to_dict(part.function_call)["args"]function_call["arguments"]=json.dumps({k:function_call_args_dict[k]forkinfunction_call_args_dict})additional_kwargs["function_call"]=function_callifstreaming:index=function_call.get("index")tool_call_chunks.append(tool_call_chunk(name=function_call.get("name"),args=function_call.get("arguments"),id=function_call.get("id",str(uuid.uuid4())),index=int(index)ifindexelseNone,))else:try:tool_calls_dicts=parse_tool_calls([{"function":function_call}],return_id=False,)tool_calls.extend([create_tool_call(name=tool_call["name"],args=tool_call["args"],id=tool_call.get("id",str(uuid.uuid4())),)fortool_callintool_calls_dicts])exceptExceptionase:invalid_tool_calls.append(invalid_tool_call(name=function_call.get("name"),args=function_call.get("arguments"),id=function_call.get("id",str(uuid.uuid4())),error=str(e),))ifcontentisNone:content=""ifstreaming:returnAIMessageChunk(content=cast(Union[str,List[Union[str,Dict[Any,Any]]]],content),additional_kwargs=additional_kwargs,tool_call_chunks=tool_call_chunks,)returnAIMessage(content=cast(Union[str,List[Union[str,Dict[Any,Any]]]],content),tool_calls=tool_calls,additional_kwargs=additional_kwargs,invalid_tool_calls=invalid_tool_calls,)def_completion_with_retry(generation_method:Callable,*,max_retries:int,run_manager:Optional[CallbackManagerForLLMRun]=None,wait_exponential_kwargs:Optional[dict[str,float]]=None,**kwargs:Any,)->Any:"""Use tenacity to retry the completion call."""retry_decorator=create_retry_decorator(max_retries=max_retries,run_manager=run_manager,wait_exponential_kwargs=wait_exponential_kwargs,)@retry_decoratordef_completion_with_retry_inner(generation_method:Callable,**kwargs:Any)->Any:returngeneration_method(**kwargs)params=({k:vfork,vinkwargs.items()ifkin_allowed_params_prediction_service}ifkwargs.get("is_gemini")elsekwargs)return_completion_with_retry_inner(generation_method,**params,)asyncdef_acompletion_with_retry(generation_method:Callable,*,max_retries:int,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,wait_exponential_kwargs:Optional[dict[str,float]]=None,**kwargs:Any,)->Any:"""Use tenacity to retry the completion call."""retry_decorator=create_retry_decorator(max_retries=max_retries,run_manager=run_manager,wait_exponential_kwargs=wait_exponential_kwargs,)@retry_decoratorasyncdef_completion_with_retry_inner(generation_method:Callable,**kwargs:Any)->Any:returnawaitgeneration_method(**kwargs)params=({k:vfork,vinkwargs.items()ifkin_allowed_params_prediction_service}ifkwargs.get("is_gemini")elsekwargs)returnawait_completion_with_retry_inner(generation_method,**params,)
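Both retry helpers above wrap the generation call in a tenacity retry decorator built by ``create_retry_decorator``. The knobs they consume (``max_retries``, ``wait_exponential_kwargs``) are exposed as ``ChatVertexAI`` constructor parameters documented in the class below. A minimal configuration sketch; the model name and timing values are illustrative only:

.. code-block:: python

    from langchain_google_vertexai import ChatVertexAI

    # Illustrative values: retries use exponential backoff configured
    # with these parameters (see the class docstring for the defaults).
    llm = ChatVertexAI(
        model="gemini-1.5-flash-001",
        max_retries=6,
        wait_exponential_kwargs={
            "multiplier": 1.0,  # initial wait time multiplier
            "min": 4.0,         # minimum wait in seconds
            "max": 10.0,        # maximum wait in seconds
            "exp_base": 2.0,    # exponent base
        },
    )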
[docs]classChatVertexAI(_VertexAICommon,BaseChatModel):"""Google Cloud Vertex AI chat model integration. Setup: You must either: - Have credentials configured for your environment (gcloud, workload identity, etc...) - Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable This codebase uses the google.auth library which first looks for the application credentials variable mentioned above, and then looks for system-level auth. For more information, see: https://cloud.google.com/docs/authentication/application-default-credentials#GAC and https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth. Key init args — completion params: model: str Name of ChatVertexAI model to use. e.g. "gemini-1.5-flash-001", "gemini-1.5-pro-001", etc. temperature: Optional[float] Sampling temperature. seed: Optional[int] Sampling integer to use. max_tokens: Optional[int] Max number of tokens to generate. stop: Optional[List[str]] Default stop sequences. safety_settings: Optional[Dict[vertexai.generative_models.HarmCategory, vertexai.generative_models.HarmBlockThreshold]] The default safety settings to use for all generations. Key init args — client params: max_retries: int Max number of retries. wait_exponential_kwargs: Optional[dict[str, float]] Optional dictionary with parameters for wait_exponential: - multiplier: Initial wait time multiplier (default: 1.0) - min: Minimum wait time in seconds (default: 4.0) - max: Maximum wait time in seconds (default: 10.0) - exp_base: Exponent base to use (default: 2.0) credentials: Optional[google.auth.credentials.Credentials] The default custom credentials to use when making API calls. If not provided, credentials will be ascertained from the environment. project: Optional[str] The default GCP project to use when making Vertex API calls. location: str = "us-central1" The default location to use when making API calls. request_parallelism: int = 5 The amount of parallelism allowed for requests issued to VertexAI models. Default is 5. base_url: Optional[str] Base URL for API requests. See full list of supported init args and their descriptions in the params section. Instantiate: .. code-block:: python from langchain_google_vertexai import ChatVertexAI llm = ChatVertexAI( model="gemini-1.5-flash-001", temperature=0, max_tokens=None, max_retries=6, stop=None, # other params... ) Invoke: .. code-block:: python messages = [ ("system", "You are a helpful translator. Translate the user sentence to French."), ("human", "I love programming."), ] llm.invoke(messages) .. code-block:: python AIMessage(content="J'adore programmer. 
\n", response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}], 'citation_metadata': None, 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 7, 'total_token_count': 24}}, id='run-925ce305-2268-44c4-875f-dde9128520ad-0') Stream: .. code-block:: python for chunk in llm.stream(messages): print(chunk) .. code-block:: python AIMessageChunk(content='J', response_metadata={'is_blocked': False, 'safety_ratings': [], 'citation_metadata': None}, id='run-9df01d73-84d9-42db-9d6b-b1466a019e89') AIMessageChunk(content="'adore programmer. \n", response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}], 'citation_metadata': None}, id='run-9df01d73-84d9-42db-9d6b-b1466a019e89') AIMessageChunk(content='', response_metadata={'is_blocked': False, 'safety_ratings': [], 'citation_metadata': None, 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 7, 'total_token_count': 24}}, id='run-9df01d73-84d9-42db-9d6b-b1466a019e89') .. code-block:: python stream = llm.stream(messages) full = next(stream) for chunk in stream: full += chunk full .. code-block:: python AIMessageChunk(content="J'adore programmer. 
\n", response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}], 'citation_metadata': None, 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 7, 'total_token_count': 24}}, id='run-b7f7492c-4cb5-42d0-8fc3-dce9b293b0fb') Async: .. code-block:: python await llm.ainvoke(messages) # stream: # async for chunk in (await llm.astream(messages)) # batch: # await llm.abatch([messages]) .. code-block:: python AIMessage(content="J'adore programmer. \n", response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}], 'citation_metadata': None, 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 7, 'total_token_count': 24}}, id='run-925ce305-2268-44c4-875f-dde9128520ad-0') Tool calling: .. code-block:: python from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' location: str = Field(..., description="The city and state, e.g. San Francisco, CA") class GetPopulation(BaseModel): '''Get the current population in a given location''' location: str = Field(..., description="The city and state, e.g. San Francisco, CA") llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") ai_msg.tool_calls .. code-block:: python [{'name': 'GetWeather', 'args': {'location': 'Los Angeles, CA'}, 'id': '2a2401fa-40db-470d-83ce-4e52de910d9e'}, {'name': 'GetWeather', 'args': {'location': 'New York City, NY'}, 'id': '96761deb-ab7f-4ef9-b4b4-6d44562fc46e'}, {'name': 'GetPopulation', 'args': {'location': 'Los Angeles, CA'}, 'id': '9147d532-abee-43a2-adb5-12f164300484'}, {'name': 'GetPopulation', 'args': {'location': 'New York City, NY'}, 'id': 'c43374ea-bde5-49ca-8487-5b83ebeea1e6'}] See ``ChatVertexAI.bind_tools()`` method for more. Use Search with Gemini 2: .. 
code-block:: python from google.cloud.aiplatform_v1beta1.types import Tool as VertexTool llm = ChatVertexAI(model="gemini-2.0-flash-exp") resp = llm.invoke( "When is the next total solar eclipse in US?", tools=[VertexTool(google_search={})], ) Structured output: .. code-block:: python from typing import Optional from pydantic import BaseModel, Field class Joke(BaseModel): '''Joke to tell user.''' setup: str = Field(description="The setup of the joke") punchline: str = Field(description="The punchline to the joke") rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") structured_llm = llm.with_structured_output(Joke) structured_llm.invoke("Tell me a joke about cats") .. code-block:: python Joke(setup='What do you call a cat that loves to bowl?', punchline='An alley cat!', rating=None) See ``ChatVertexAI.with_structured_output()`` for more. Image input: .. code-block:: python import base64 import httpx from langchain_core.messages import HumanMessage image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") message = HumanMessage( content=[ {"type": "text", "text": "describe the weather in this image"}, { "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, }, ], ) ai_msg = llm.invoke([message]) ai_msg.content .. code-block:: python 'The weather in this image appears to be sunny and pleasant. The sky is a bright blue with scattered white clouds, suggesting a clear and mild day. The lush green grass indicates recent rainfall or sufficient moisture. The absence of strong shadows suggests that the sun is high in the sky, possibly late afternoon. Overall, the image conveys a sense of tranquility and warmth, characteristic of a beautiful summer day. \n' You can also point to GCS files which is faster / more efficient because bytes are transferred back and forth. .. code-block:: python llm.invoke( [ HumanMessage( [ "What's in the image?", { "type": "media", "file_uri": "gs://cloud-samples-data/generative-ai/image/scones.jpg", "mime_type": "image/jpeg", }, ] ) ] ).content .. code-block:: python 'The image is of five blueberry scones arranged on a piece of baking paper. \n\nHere is a list of what is in the picture:\n* **Five blueberry scones:** They are scattered across the parchment paper, dusted with powdered sugar. \n* **Two cups of coffee:** Two white cups with saucers. One appears full, the other partially drunk.\n* **A bowl of blueberries:** A brown bowl is filled with fresh blueberries, placed near the scones.\n* **A spoon:** A silver spoon with the words "Let\'s Jam" rests on the paper.\n* **Pink peonies:** Several pink peonies lie beside the scones, adding a touch of color.\n* **Baking paper:** The scones, cups, bowl, and spoon are arranged on a piece of white baking paper, splattered with purple. The paper is crinkled and sits on a dark surface. \n\nThe image has a rustic and delicious feel, suggesting a cozy and enjoyable breakfast or brunch setting. \n' # codespell:ignore brunch Video input: **NOTE**: Currently only supported for ``gemini-...-vision`` models. .. code-block:: python llm = ChatVertexAI(model="gemini-1.0-pro-vision") llm.invoke( [ HumanMessage( [ "What's in the video?", { "type": "media", "file_uri": "gs://cloud-samples-data/video/animals.mp4", "mime_type": "video/mp4", }, ] ) ] ).content .. 
code-block:: python 'The video is about a new feature in Google Photos called "Zoomable Selfies". The feature allows users to take selfies with animals at the zoo. The video shows several examples of people taking selfies with animals, including a tiger, an elephant, and a sea otter. The video also shows how the feature works. Users simply need to open the Google Photos app and select the "Zoomable Selfies" option. Then, they need to choose an animal from the list of available animals. The app will then guide the user through the process of taking the selfie.' Audio input: .. code-block:: python from langchain_core.messages import HumanMessage llm = ChatVertexAI(model="gemini-1.5-flash-001") llm.invoke( [ HumanMessage( [ "What's this audio about?", { "type": "media", "file_uri": "gs://cloud-samples-data/generative-ai/audio/pixel.mp3", "mime_type": "audio/mpeg", }, ] ) ] ).content .. code-block:: python "This audio is an interview with two product managers from Google who work on Pixel feature drops. They discuss how feature drops are important for showcasing how Google devices are constantly improving and getting better. They also discuss some of the highlights of the January feature drop and the new features coming in the March drop for Pixel phones and Pixel watches. The interview concludes with discussion of how user feedback is extremely important to them in deciding which features to include in the feature drops. " Token usage: .. code-block:: python ai_msg = llm.invoke(messages) ai_msg.usage_metadata .. code-block:: python {'input_tokens': 17, 'output_tokens': 7, 'total_tokens': 24} Logprobs: .. code-block:: python llm = ChatVertexAI(model="gemini-1.5-flash-001", logprobs=True) ai_msg = llm.invoke(messages) ai_msg.response_metadata["logprobs_result"] .. code-block:: python [ {'token': 'J', 'logprob': -1.549651415189146e-06, 'top_logprobs': []}, {'token': "'", 'logprob': -1.549651415189146e-06, 'top_logprobs': []}, {'token': 'adore', 'logprob': 0.0, 'top_logprobs': []}, {'token': ' programmer', 'logprob': -1.1922384146600962e-07, 'top_logprobs': []}, {'token': '.', 'logprob': -4.827636439586058e-05, 'top_logprobs': []}, {'token': ' ', 'logprob': -0.018011733889579773, 'top_logprobs': []}, {'token': '\n', 'logprob': -0.0008687592926435173, 'top_logprobs': []} ] Response metadata .. code-block:: python ai_msg = llm.invoke(messages) ai_msg.response_metadata .. code-block:: python {'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}], 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 7, 'total_token_count': 24}} Safety settings .. 
code-block:: python from langchain_google_vertexai import HarmBlockThreshold, HarmCategory llm = ChatVertexAI( model="gemini-1.5-pro", safety_settings={ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, }, ) llm.invoke(messages).response_metadata .. code-block:: python {'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'probability_score': 0.1, 'blocked': False, 'severity': 'HARM_SEVERITY_NEGLIGIBLE', 'severity_score': 0.1}], 'usage_metadata': {'prompt_token_count': 17, 'candidates_token_count': 7, 'total_token_count': 24}} """# noqa: E501model_name:str=Field(default="chat-bison-default",alias="model")"Underlying model name."examples:Optional[List[BaseMessage]]=Noneconvert_system_message_to_human:bool=False"""[Deprecated] Since new Gemini models support setting a System Message, setting this parameter to True is discouraged. """response_mime_type:Optional[str]=None"""Optional. Output response mimetype of the generated candidate text. Only supported in Gemini 1.5 and later models. Supported mimetype: * "text/plain": (default) Text output. * "application/json": JSON response in the candidates. * "text/x.enum": Enum in plain text. The model also needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. This is a preview feature. """response_schema:Optional[Dict[str,Any]]=None""" Optional. Enforce an schema to the output. The format of the dictionary should follow Open API schema. """cached_content:Optional[str]=None""" Optional. Use the model in cache mode. Only supported in Gemini 1.5 and later models. Must be a string containing the cache name (A sequence of numbers) """logprobs:Union[bool,int]=False"""Whether to return logprobs as part of AIMessage.response_metadata. If False, don't return logprobs. If True, return logprobs for top candidate. If int, return logprobs for top ``logprobs`` candidates. **NOTE**: As of 10.28.24 this is only supported for gemini-1.5-flash models. .. versionadded: 2.0.6 """labels:Optional[Dict[str,str]]=None""" Optional tag llm calls with metadata to help in tracebility and biling. """perform_literal_eval_on_string_raw_content:bool=True"""Whether to perform literal eval on string raw content. 
"""wait_exponential_kwargs:Optional[dict[str,float]]=None"""Optional dictionary with parameters for wait_exponential: - multiplier: Initial wait time multiplier (default: 1.0) - min: Minimum wait time in seconds (default: 4.0) - max: Maximum wait time in seconds (default: 10.0) - exp_base: Exponent base to use (default: 2.0) """def__init__(self,*,model_name:Optional[str]=None,**kwargs:Any)->None:"""Needed for mypy typing to recognize model_name as a valid arg."""ifmodel_name:kwargs["model_name"]=model_namesuper().__init__(**kwargs)model_config=ConfigDict(populate_by_name=True,arbitrary_types_allowed=True,)@classmethoddefis_lc_serializable(self)->bool:returnTrue@classmethoddefget_lc_namespace(cls)->List[str]:"""Get the namespace of the langchain object."""return["langchain","chat_models","vertexai"]@model_validator(mode="after")defvalidate_labels(self)->Self:ifself.labels:forkey,valueinself.labels.items():ifnotre.match(r"^[a-z][a-z0-9-_]{0,62}$",key):raiseValueError(f"Invalid label key: {key}")ifvalueandlen(value)>63:raiseValueError(f"Label value too long: {value}")returnself@cached_propertydef_image_bytes_loader_client(self):returnImageBytesLoader(project=self.project)@model_validator(mode="after")defvalidate_environment(self)->Self:"""Validate that the python package exists in environment."""safety_settings=self.safety_settingstuned_model_name=self.tuned_model_nameself.model_family=GoogleModelFamily(self.model_name)ifself.model_name=="chat-bison-default":logger.warning("Model_name will become a required arg for VertexAIEmbeddings ""starting from Sep-01-2024. Currently the default is set to ""chat-bison")self.model_name="chat-bison"ifself.full_model_nameisnotNone:passelifself.tuned_model_nameisnotNone:self.full_model_name=_format_model_name(self.tuned_model_name,location=self.location,project=cast(str,self.project),)else:self.full_model_name=_format_model_name(self.model_name,location=self.location,project=cast(str,self.project),)ifsafety_settingsandnotis_gemini_model(self.model_family):raiseValueError("Safety settings are only supported for Gemini models")iftuned_model_name:generative_model_name=self.tuned_model_nameelse:generative_model_name=self.model_nameifnotis_gemini_model(self.model_family):logger.warning("Non-Gemini models are deprecated. ""They will be removed starting Dec-01-2024. 
")values={"project":self.project,"location":self.location,"credentials":self.credentials,"api_transport":self.api_transport,"api_endpoint":self.api_endpoint,"default_metadata":self.default_metadata,}self._init_vertexai(values)ifself.model_family==GoogleModelFamily.CODEY:model_cls=CodeChatModelmodel_cls_preview=PreviewCodeChatModelelse:model_cls=ChatModelmodel_cls_preview=PreviewChatModelself.client=model_cls.from_pretrained(generative_model_name)self.client_preview=model_cls_preview.from_pretrained(generative_model_name)returnself@propertydef_is_gemini_advanced(self)->bool:returnself.model_family==GoogleModelFamily.GEMINI_ADVANCEDdef_prepare_params(self,stop:Optional[List[str]]=None,stream:bool=False,**kwargs:Any,)->dict:params=super()._prepare_params(stop=stop,stream=stream,**kwargs)response_mime_type=kwargs.get("response_mime_type",self.response_mime_type)ifresponse_mime_typeisnotNone:params["response_mime_type"]=response_mime_typeresponse_schema=kwargs.get("response_schema",self.response_schema)ifresponse_schemaisnotNone:allowed_mime_types=("application/json","text/x.enum")ifresponse_mime_typenotinallowed_mime_types:error_message=("`response_schema` is only supported when "f"`response_mime_type` is set to one of {allowed_mime_types}")raiseValueError(error_message)gapic_response_schema=_convert_schema_dict_to_gapic(response_schema)params["response_schema"]=gapic_response_schemareturnparamsdef_get_ls_params(self,stop:Optional[List[str]]=None,**kwargs:Any)->LangSmithParams:"""Get standard params for tracing."""params=self._prepare_params(stop=stop,**kwargs)ls_params=LangSmithParams(ls_provider="google_vertexai",ls_model_name=self.model_name,ls_model_type="chat",ls_temperature=params.get("temperature",self.temperature),)ifls_max_tokens:=params.get("max_output_tokens",self.max_output_tokens):ls_params["ls_max_tokens"]=ls_max_tokensifls_stop:=stoporparams.get("stop",None)orself.stop:ls_params["ls_stop"]=ls_stopreturnls_paramsdef_generate(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,stream:Optional[bool]=None,**kwargs:Any,)->ChatResult:"""Generate next turn in the conversation. Args: messages: The history of the conversation as a list of messages. Code chat does not support context. stop: The list of stop words (optional). run_manager: The CallbackManager for LLM run, it's not used at the moment. stream: Whether to use the streaming endpoint. Returns: The ChatResult that contains outputs generated by the model. Raises: ValueError: if the last message in the list is not from human. """ifstreamisTrueor(streamisNoneandself.streaming):stream_iter=self._stream(messages,stop=stop,run_manager=run_manager,**kwargs)returngenerate_from_stream(stream_iter)ifnotself._is_gemini_model:returnself._generate_non_gemini(messages,stop=stop,**kwargs)returnself._generate_gemini(messages=messages,stop=stop,run_manager=run_manager,is_gemini=True,**kwargs,)def_generation_config_gemini(self,stop:Optional[List[str]]=None,stream:bool=False,*,logprobs:int|bool=False,**kwargs:Any,)->Union[GenerationConfig,v1GenerationConfig]:"""Prepares GenerationConfig part of the request. 
https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#generationconfig """iflogprobsandisinstance(logprobs,bool):kwargs["response_logprobs"]=logprobseliflogprobsandisinstance(logprobs,int):kwargs["response_logprobs"]=Truekwargs["logprobs"]=logprobselse:passifself.endpoint_version=="v1":returnv1GenerationConfig(**self._prepare_params(stop=stop,stream=stream,**{k:vfork,vinkwargs.items()ifkin_allowed_params},))returnGenerationConfig(**self._prepare_params(stop=stop,stream=stream,**{k:vfork,vinkwargs.items()ifkin_allowed_params},))def_safety_settings_gemini(self,safety_settings:Optional[SafetySettingsType])->Optional[Sequence[SafetySetting]]:"""Prepares SafetySetting part of the request. https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#safetysetting """ifsafety_settingsisNone:ifself.safety_settings:returnself._safety_settings_gemini(self.safety_settings)returnNoneifisinstance(safety_settings,list):returnsafety_settingsifisinstance(safety_settings,dict):formatted_safety_settings=[]forcategory,thresholdinsafety_settings.items():ifisinstance(category,str):category=HarmCategory[category]# type: ignore[misc]ifisinstance(threshold,str):threshold=SafetySetting.HarmBlockThreshold[threshold]# type: ignore[misc]formatted_safety_settings.append(SafetySetting(category=HarmCategory(category),threshold=SafetySetting.HarmBlockThreshold(threshold),))returnformatted_safety_settingsraiseValueError("safety_settings should be either")def_prepare_request_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,stream:bool=False,tools:Optional[_ToolsType]=None,functions:Optional[_ToolsType]=None,tool_config:Optional[Union[_ToolConfigDict,ToolConfig]]=None,safety_settings:Optional[SafetySettingsType]=None,cached_content:Optional[str]=None,*,tool_choice:Optional[_ToolChoiceType]=None,logprobs:Optional[Union[int,bool]]=None,**kwargs,)->Union[v1GenerateContentRequest,GenerateContentRequest]:system_instruction,contents=_parse_chat_history_gemini(messages,self._image_bytes_loader_client,perform_literal_eval_on_string_raw_content=self.perform_literal_eval_on_string_raw_content,)formatted_tools=self._tools_gemini(tools=tools,functions=functions)iftool_config:tool_config=self._tool_config_gemini(tool_config=tool_config)eliftool_choice:all_names=[f.namefortoolin(formatted_toolsor[])forfintool.function_declarations]tool_config=_tool_choice_to_tool_config(tool_choice,all_names)else:passsafety_settings=self._safety_settings_gemini(safety_settings)logprobs=logprobsiflogprobsisnotNoneelseself.logprobslogprobs=logprobsifisinstance(logprobs,(int,bool))elseFalsegeneration_config=self._generation_config_gemini(stream=stream,stop=stop,logprobs=logprobs,**kwargs)def_content_to_v1(contents:list[Content])->list[v1Content]:v1_contens=[]forcontentincontents:v1_parts=[]forpartincontent.parts:raw_part=proto.Message.to_dict(part)_=raw_part.pop("thought")v1_parts.append(v1Part(**raw_part))v1_contens.append(v1Content(role=content.role,parts=v1_parts))returnv1_contensv1_system_instruction,v1_tools,v1_tool_config,v1_safety_settings=(None,None,None,None,)ifself.endpoint_version=="v1":v1_system_instruction=(_content_to_v1([system_instruction])[0]ifsystem_instructionelseNone)ifformatted_tools:v1_tools=[v1Tool(**proto.Message.to_dict(t))fortinformatted_tools]iftool_config:v1_tool_config=v1ToolConfig(function_calling_config=v1FunctionCallingConfig(**proto.Message.to_dict(tool_config.function_calling_config)))ifsafety_settings:v1_safety_settings=[v1SafetySetting(category=s.
category,method=s.method,threshold=s.threshold)forsinsafety_settings]if(self.cached_contentisnotNone)or(cached_contentisnotNone):selected_cached_content=self.cached_contentorcached_contentfull_cache_name=self._request_from_cached_content(cached_content=selected_cached_content,# type: ignoresystem_instruction=system_instruction,tools=formatted_tools,tool_config=tool_config,)ifself.endpoint_version=="v1":returnGenerateContentRequest(contents=_content_to_v1(contents),model=self.full_model_name,safety_settings=v1_safety_settings,generation_config=generation_config,cached_content=full_cache_name,)returnGenerateContentRequest(contents=contents,model=self.full_model_name,safety_settings=safety_settings,generation_config=generation_config,cached_content=full_cache_name,)ifself.endpoint_version=="v1":returnv1GenerateContentRequest(contents=_content_to_v1(contents),system_instruction=v1_system_instruction,tools=v1_tools,tool_config=v1_tool_config,safety_settings=v1_safety_settings,generation_config=generation_config,model=self.full_model_name,labels=self.labels,)returnGenerateContentRequest(contents=contents,system_instruction=system_instruction,tools=formatted_tools,tool_config=tool_config,safety_settings=safety_settings,generation_config=generation_config,model=self.full_model_name,labels=self.labels,)def_request_from_cached_content(self,cached_content:str,system_instruction:Optional[Content],tools:Optional[Sequence[GapicTool]],tool_config:Optional[Union[_ToolConfigDict,ToolConfig]],)->str:not_allowed_parameters=[("system_instructions",system_instruction),("tools",tools),("tool_config",tool_config),]forparam_name,parameterinnot_allowed_parameters:ifparameter:message=(f"Using cached content. Parameter `{param_name}` will be ignored. ")logger.warning(message)return(f"projects/{self.project}/locations/{self.location}/"f"cachedContents/{cached_content}")def_generate_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->ChatResult:request=self._prepare_request_gemini(messages=messages,stop=stop,**kwargs)response=_completion_with_retry(self.prediction_client.generate_content,max_retries=self.max_retries,run_manager=run_manager,wait_exponential_kwargs=self.wait_exponential_kwargs,request=request,metadata=self.default_metadata,**kwargs,)returnself._gemini_response_to_chat_result(response)asyncdef_agenerate_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->ChatResult:response=await_acompletion_with_retry(self.async_prediction_client.generate_content,max_retries=self.max_retries,run_manager=run_manager,wait_exponential_kwargs=self.wait_exponential_kwargs,request=self._prepare_request_gemini(messages=messages,stop=stop,**kwargs),is_gemini=True,metadata=self.default_metadata,**kwargs,)returnself._gemini_response_to_chat_result(response)
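``_prepare_request_gemini`` above swaps in a full ``projects/{project}/locations/{location}/cachedContents/{id}`` resource name whenever ``cached_content`` is set, and ``_request_from_cached_content`` warns that system instructions, tools, and tool_config are ignored in that case. A hedged usage sketch; the cache ID below is a placeholder for a context cache you have already created in the same project and location:

.. code-block:: python

    from langchain_google_vertexai import ChatVertexAI

    # "1234567890" is a placeholder cache ID (a sequence of numbers),
    # referring to an existing Vertex AI context cache.
    llm = ChatVertexAI(
        model="gemini-1.5-pro-001",
        cached_content="1234567890",
    )
    # System instructions, tools, and tool_config are ignored when a cache is used.
    response = llm.invoke("Summarize the cached document in two sentences.")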
    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        if self._is_gemini_model:
            # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#counttokensrequest
            _, contents = _parse_chat_history_gemini(
                [HumanMessage(content=text)],
                self._image_bytes_loader_client,
                perform_literal_eval_on_string_raw_content=self.perform_literal_eval_on_string_raw_content,
            )
            response = self.prediction_client.count_tokens(  # type: ignore[union-attr]
                {
                    "endpoint": self.full_model_name,
                    "model": self.full_model_name,
                    "contents": contents,
                }
            )
            return response.total_tokens
        else:
            return self.client_preview.start_chat().count_tokens(text)
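A short usage sketch for ``get_num_tokens``; note that for Gemini models it issues a CountTokens request against the prediction service rather than counting locally. The model name is illustrative:

.. code-block:: python

    from langchain_google_vertexai import ChatVertexAI

    llm = ChatVertexAI(model="gemini-1.5-flash-001")
    # For Gemini models this performs a network call to the CountTokens API.
    n_tokens = llm.get_num_tokens("The quick brown fox jumps over the lazy dog.")
    print(n_tokens)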
def_tools_gemini(self,tools:Optional[_ToolsType]=None,functions:Optional[_ToolsType]=None,)->Optional[List[GapicTool]]:iftoolsandfunctions:logger.warning("Binding tools and functions together is not supported.","Only tools will be used",)iftools:return[_format_to_gapic_tool(tools)]iffunctions:return[_format_to_gapic_tool(functions)]returnNonedef_tool_config_gemini(self,tool_config:Optional[Union[_ToolConfigDict,ToolConfig]]=None)->Optional[GapicToolConfig]:iftool_configandnotisinstance(tool_config,ToolConfig):return_format_tool_config(cast(_ToolConfigDict,tool_config))returnNonedef_generate_non_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,**kwargs:Any,)->ChatResult:kwargs.pop("safety_settings",None)params=self._prepare_params(stop=stop,stream=False,**kwargs)question=_get_question(messages)history=_parse_chat_history(messages[:-1])examples=kwargs.get("examples")orself.examplesmsg_params={}if"candidate_count"inparams:msg_params["candidate_count"]=params.pop("candidate_count")ifexamples:params["examples"]=_parse_examples(examples)withtelemetry.tool_context_manager(self._user_agent):chat=self._start_chat(history,**params)response=_completion_with_retry(chat.send_message,max_retries=self.max_retries,message=question.content,**msg_params,)usage_metadata=response.raw_prediction_response.metadatalc_usage=_get_usage_metadata_non_gemini(usage_metadata)generations=[ChatGeneration(message=AIMessage(content=candidate.text,usage_metadata=lc_usage),generation_info=get_generation_info(candidate,self._is_gemini_model,usage_metadata=usage_metadata,),)forcandidateinresponse.candidates]returnChatResult(generations=generations)asyncdef_agenerate(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,stream:Optional[bool]=None,**kwargs:Any,)->ChatResult:"""Asynchronously generate next turn in the conversation. Args: messages: The history of the conversation as a list of messages. Code chat does not support context. stop: The list of stop words (optional). run_manager: The CallbackManager for LLM run, it's not used at the moment. Returns: The ChatResult that contains outputs generated by the model. Raises: ValueError: if the last message in the list is not from human. 
"""should_stream=streamisTrueor(streamisNoneandself.streaming)ifnotself._is_gemini_model:ifshould_stream:logger.warning("ChatVertexAI does not currently support async streaming.")returnawaitself._agenerate_non_gemini(messages,stop=stop,**kwargs)ifshould_stream:stream_iter=self._astream(messages,stop=stop,run_manager=run_manager,**kwargs)returnawaitagenerate_from_stream(stream_iter)returnawaitself._agenerate_gemini(messages=messages,stop=stop,run_manager=run_manager,**kwargs,)asyncdef_agenerate_non_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,**kwargs:Any,)->ChatResult:kwargs.pop("safety_settings",None)params=self._prepare_params(stop=stop,stream=False,**kwargs)question=_get_question(messages)history=_parse_chat_history(messages[:-1])examples=kwargs.get("examples")orself.examplesmsg_params={}if"candidate_count"inparams:msg_params["candidate_count"]=params.pop("candidate_count")ifexamples:params["examples"]=_parse_examples(examples)withtelemetry.tool_context_manager(self._user_agent):chat=self._start_chat(history,**params)response=await_acompletion_with_retry(chat.send_message_async,message=question.content,max_retries=self.max_retries,**msg_params,)usage_metadata=response.raw_prediction_response.metadatalc_usage=_get_usage_metadata_non_gemini(usage_metadata)generations=[ChatGeneration(message=AIMessage(content=candidate.text,usage_metadata=lc_usage),generation_info=get_generation_info(candidate,self._is_gemini_model,usage_metadata=usage_metadata,),)forcandidateinresponse.candidates]returnChatResult(generations=generations)def_stream(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->Iterator[ChatGenerationChunk]:ifnotself._is_gemini_model:yield fromself._stream_non_gemini(messages,stop=stop,run_manager=run_manager,**kwargs)returnyield 
fromself._stream_gemini(messages=messages,stop=stop,run_manager=run_manager,**kwargs)returndef_stream_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->Iterator[ChatGenerationChunk]:request=self._prepare_request_gemini(messages=messages,stop=stop,**kwargs)response_iter=_completion_with_retry(self.prediction_client.stream_generate_content,max_retries=self.max_retries,run_manager=run_manager,wait_exponential_kwargs=self.wait_exponential_kwargs,request=request,is_gemini=True,metadata=self.default_metadata,**kwargs,)total_lc_usage=Noneforresponse_chunkinresponse_iter:chunk,total_lc_usage=self._gemini_chunk_to_generation_chunk(response_chunk,prev_total_usage=total_lc_usage)ifrun_managerandisinstance(chunk.message.content,str):run_manager.on_llm_new_token(chunk.message.content)yieldchunkdef_stream_non_gemini(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[CallbackManagerForLLMRun]=None,**kwargs:Any,)->Iterator[ChatGenerationChunk]:params=self._prepare_params(stop=stop,stream=True,**kwargs)question=_get_question(messages)history=_parse_chat_history(messages[:-1])examples=kwargs.get("examples",None)ifexamples:params["examples"]=_parse_examples(examples)withtelemetry.tool_context_manager(self._user_agent):chat=self._start_chat(history,**params)responses=chat.send_message_streaming(question.content,**params)forresponseinresponses:ifrun_manager:run_manager.on_llm_new_token(response.text)yieldChatGenerationChunk(message=AIMessageChunk(content=response.text),generation_info=get_generation_info(response,self._is_gemini_model,usage_metadata=response.raw_prediction_response.metadata,),)asyncdef_astream(self,messages:List[BaseMessage],stop:Optional[List[str]]=None,run_manager:Optional[AsyncCallbackManagerForLLMRun]=None,**kwargs:Any,)->AsyncIterator[ChatGenerationChunk]:# TODO: Update to properly support async streaming from gemini.ifnotself._is_gemini_model:asyncforchunkinsuper()._astream(messages,stop=stop,run_manager=run_manager,**kwargs):yieldchunkreturnrequest=self._prepare_request_gemini(messages=messages,stop=stop,**kwargs)response_iter=_acompletion_with_retry(self.async_prediction_client.stream_generate_content,max_retries=self.max_retries,run_manager=run_manager,wait_exponential_kwargs=self.wait_exponential_kwargs,request=request,is_gemini=True,metadata=self.default_metadata,**kwargs,)total_lc_usage=Noneasyncforresponse_chunkinawaitresponse_iter:chunk,total_lc_usage=self._gemini_chunk_to_generation_chunk(response_chunk,prev_total_usage=total_lc_usage)ifrun_managerandisinstance(chunk.message.content,str):awaitrun_manager.on_llm_new_token(chunk.message.content)yieldchunk
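The ``_stream``, ``_stream_gemini``, and ``_astream`` methods above back the public ``stream``/``astream`` interface. A compact consumption sketch, mirroring the streaming example in the class docstring; the model name and prompt are illustrative:

.. code-block:: python

    import asyncio

    from langchain_google_vertexai import ChatVertexAI

    llm = ChatVertexAI(model="gemini-1.5-flash-001")
    messages = [("human", "Tell me a one-line joke about streams.")]

    # Synchronous streaming (_stream / _stream_gemini under the hood).
    for chunk in llm.stream(messages):
        print(chunk.content, end="", flush=True)

    # Asynchronous streaming (_astream); Gemini models stream natively.
    async def main() -> None:
        async for chunk in llm.astream(messages):
            print(chunk.content, end="", flush=True)

    asyncio.run(main())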
[docs]    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel], Type],
        *,
        include_raw: bool = False,
        method: Optional[Literal["json_mode"]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        .. versionchanged:: 1.1.0

            Return type corrected in version 1.1.0. Previously, if a dict schema
            was provided then the output had the form
            ``[{"args": {}, "name": "schema_name"}]`` where the output was a list
            with a single dict and the "args" of the one dict corresponded to the
            schema. As of ``1.1.0`` this has been fixed so that the schema (the
            value corresponding to the old "args" key) is returned directly.

        Args:
            schema: The output schema as a dict or a Pydantic class. If a Pydantic
                class then the model output will be an object of that class. If a
                dict then the model output will be a dict. With a Pydantic class
                the returned attributes will be validated, whereas with a dict they
                will not be. If `method` is "function_calling" and `schema` is a
                dict, then the dict must match the OpenAI function-calling spec.
            include_raw: If False then only the parsed structured output is
                returned. If an error occurs during model output parsing it will be
                raised. If True then both the raw model response (a BaseMessage)
                and the parsed model response will be returned. If an error occurs
                during output parsing it will be caught and returned as well. The
                final output is always a dict with keys "raw", "parsed", and
                "parsing_error".
            method: If set to "json_mode" it will use controlled generation to
                generate the response rather than function calling. Does not work
                with schemas with references or Pydantic models with
                self-references.

        Returns:
            A Runnable that takes any ChatModel input. If include_raw is True then
            a dict with keys: raw (BaseMessage), parsed (Optional[_DictOrPydantic]),
            parsing_error (Optional[BaseException]). If include_raw is False then
            just _DictOrPydantic is returned, where _DictOrPydantic depends on the
            schema. If schema is a Pydantic class then _DictOrPydantic is the
            Pydantic class. If schema is a dict then _DictOrPydantic is a dict.

        Example: Pydantic schema, exclude raw:
            .. code-block:: python

                from pydantic import BaseModel
                from langchain_google_vertexai import ChatVertexAI

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatVertexAI(model_name="gemini-pro", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> AnswerWithJustification(
                #     answer='They weigh the same.', justification='A pound is a pound.'
                # )

        Example: Pydantic schema, include raw:
            .. code-block:: python

                from pydantic import BaseModel
                from langchain_google_vertexai import ChatVertexAI

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatVertexAI(model_name="gemini-pro", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }

        Example: Dict schema, exclude raw:
            .. code-block:: python

                from pydantic import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_function
                from langchain_google_vertexai import ChatVertexAI

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                dict_schema = convert_to_openai_function(AnswerWithJustification)
                llm = ChatVertexAI(model_name="gemini-pro", temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }
        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        parser: OutputParserLike
        if method == "json_mode":
            schema_is_typeddict = is_typeddict(schema)
            if isinstance(schema, type) and not schema_is_typeddict:
                # TODO: This gets the json schema of a pydantic model. It fails for
                # nested models because the generated schema contains $refs that the
                # gemini api doesn't support. We can implement a postprocessing
                # function that takes care of this if necessary.
                if issubclass(schema, BaseModelV1):
                    schema_json = schema.schema()
                else:
                    schema_json = schema.model_json_schema()
                schema_json = replace_defs_in_schema(schema_json)
                parser = PydanticOutputParser(pydantic_object=schema)
            else:
                if schema_is_typeddict:
                    schema_json = convert_to_json_schema(schema)
                else:
                    schema_json = cast(dict, schema)
                parser = JsonOutputParser()
            llm = self.bind(
                response_mime_type="application/json",
                response_schema=schema_json,
                ls_structured_output_format={
                    "kwargs": {"method": method},
                    "schema": convert_to_json_schema(schema),
                },
            )
        else:
            tool_name = _get_tool_name(schema)
            if isinstance(schema, type) and is_basemodel_subclass(schema):
                parser = PydanticToolsParser(tools=[schema], first_tool_only=True)
            else:
                parser = JsonOutputKeyToolsParser(
                    key_name=tool_name, first_tool_only=True
                )
            tool_choice = tool_name if self._is_gemini_advanced else None
            try:
                llm = self.bind_tools(
                    [schema],
                    tool_choice=tool_choice,
                    ls_structured_output_format={
                        "kwargs": {"method": "function_calling"},
                        "schema": convert_to_openai_tool(schema),
                    },
                )
            except Exception:
                llm = self.bind_tools([schema], tool_choice=tool_choice)
        if include_raw:
            parser_with_fallback = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
            ).with_fallbacks(
                [RunnablePassthrough.assign(parsed=lambda _: None)],
                exception_key="parsing_error",
            )
            return {"raw": llm} | parser_with_fallback
        else:
            return llm | parser
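    # A minimal sketch (not in the original source) of the "json_mode" branch
    # above, assuming a TypedDict schema: is_typeddict(schema) is True, so the
    # schema is converted with convert_to_json_schema and the output is parsed
    # with JsonOutputParser into a plain dict. Model name and prompt are
    # illustrative assumptions.
    #
    #     from typing_extensions import TypedDict
    #     from langchain_google_vertexai import ChatVertexAI
    #
    #     class Answer(TypedDict):
    #         answer: str
    #         justification: str
    #
    #     llm = ChatVertexAI(model_name="gemini-1.5-pro", temperature=0)
    #     structured_llm = llm.with_structured_output(Answer, method="json_mode")
    #     structured_llm.invoke("What weighs more, a pound of bricks or a pound of feathers?")
    #     # -> {"answer": "...", "justification": "..."}  (a dict, not a Pydantic object)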
[docs]    def bind_tools(
        self,
        tools: _ToolsType,
        tool_config: Optional[_ToolConfigDict] = None,
        *,
        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with the Vertex tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a pydantic model, callable, or BaseTool. Pydantic models,
                callables, and BaseTools will be automatically converted to their
                schema dictionary representation.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        if tool_choice and tool_config:
            raise ValueError(
                "Must specify at most one of tool_choice and tool_config, received "
                f"both:\n\n{tool_choice=}\n\n{tool_config=}"
            )
        try:
            formatted_tools = [convert_to_openai_tool(tool) for tool in tools]  # type: ignore[arg-type]
        except Exception:
            formatted_tools = [_format_to_gapic_tool(tools)]
        if tool_choice:
            kwargs["tool_choice"] = tool_choice
        elif tool_config:
            kwargs["tool_config"] = tool_config
        else:
            pass
        return self.bind(tools=formatted_tools, **kwargs)
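    # A minimal bind_tools sketch (not in the original source), assuming a plain
    # Python function as the tool; convert_to_openai_tool derives the schema from
    # the signature and docstring. The function, model name, and prompt are
    # illustrative assumptions.
    #
    #     from langchain_google_vertexai import ChatVertexAI
    #
    #     def get_weather(city: str) -> str:
    #         """Return the current weather for a city."""
    #         return f"Sunny in {city}"
    #
    #     llm = ChatVertexAI(model_name="gemini-1.5-pro")
    #     llm_with_tools = llm.bind_tools([get_weather])
    #     ai_msg = llm_with_tools.invoke("What's the weather in Paris?")
    #     ai_msg.tool_calls
    #     # -> [{"name": "get_weather", "args": {"city": "Paris"}, ...}]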
    def _start_chat(
        self, history: _ChatHistory, **kwargs: Any
    ) -> Union[ChatSession, CodeChatSession]:
        if self.model_family == GoogleModelFamily.CODEY:
            return self.client.start_chat(
                context=history.context, message_history=history.history, **kwargs
            )
        else:
            return self.client.start_chat(message_history=history.history, **kwargs)

    def _gemini_response_to_chat_result(
        self, response: GenerationResponse
    ) -> ChatResult:
        generations = []
        usage = proto.Message.to_dict(response.usage_metadata)
        lc_usage = _get_usage_metadata_gemini(usage)
        logprobs = self.logprobs if isinstance(self.logprobs, (int, bool)) else False
        for candidate in response.candidates:
            info = get_generation_info(
                candidate, is_gemini=True, usage_metadata=usage, logprobs=logprobs
            )
            message = _parse_response_candidate(candidate)
            message.response_metadata["model_name"] = self.model_name
            if isinstance(message, AIMessage):
                message.usage_metadata = lc_usage
            generations.append(ChatGeneration(message=message, generation_info=info))
        if not response.candidates:
            message = AIMessage(content="")
            message.response_metadata["model_name"] = self.model_name
            if usage:
                generation_info = {"usage_metadata": usage}
                message.usage_metadata = lc_usage
            else:
                generation_info = {}
            generations.append(
                ChatGeneration(message=message, generation_info=generation_info)
            )
        return ChatResult(generations=generations)

    def _gemini_chunk_to_generation_chunk(
        self,
        response_chunk: GenerationResponse,
        prev_total_usage: Optional[UsageMetadata] = None,
    ) -> Tuple[ChatGenerationChunk, Optional[UsageMetadata]]:
        # Return an empty completion message if there are no candidates.
        usage_metadata = proto.Message.to_dict(response_chunk.usage_metadata)
        # Gather langchain (standard) usage metadata.
        # Note: some models (e.g., gemini-1.5-pro with image inputs) return
        # cumulative sums of token counts.
        total_lc_usage = _get_usage_metadata_gemini(usage_metadata)
        if total_lc_usage and prev_total_usage:
            lc_usage: Optional[UsageMetadata] = UsageMetadata(
                input_tokens=total_lc_usage["input_tokens"]
                - prev_total_usage["input_tokens"],
                output_tokens=total_lc_usage["output_tokens"]
                - prev_total_usage["output_tokens"],
                total_tokens=total_lc_usage["total_tokens"]
                - prev_total_usage["total_tokens"],
            )
        else:
            lc_usage = total_lc_usage
        if not response_chunk.candidates:
            message = AIMessageChunk(content="")
            if lc_usage:
                message.usage_metadata = lc_usage
            generation_info = {}
        else:
            top_candidate = response_chunk.candidates[0]
            message = _parse_response_candidate(top_candidate, streaming=True)
            if lc_usage:
                message.usage_metadata = lc_usage
            generation_info = get_generation_info(
                top_candidate,
                is_gemini=True,
                usage_metadata={},
            )
            # Add the model name if this is the final chunk.
            if generation_info.get("finish_reason"):
                message.response_metadata["model_name"] = self.model_name
            # is_blocked is part of the "safety_ratings" list, but if it's
            # True/False then chunks can't be merged.
            generation_info.pop("is_blocked", None)
        return (
            ChatGenerationChunk(
                message=message,
                generation_info=generation_info,
            ),
            total_lc_usage,
        )
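    # Worked example (illustrative numbers, not from the source) of the cumulative
    # usage handling above: if a chunk reports prompt=10, candidates=7, total=17
    # and the previous running total was prompt=10, candidates=4, total=14, the
    # chunk is attributed UsageMetadata(input_tokens=0, output_tokens=3,
    # total_tokens=3), while the running total returned alongside it becomes the
    # new cumulative value (10, 7, 17).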
def _get_usage_metadata_gemini(raw_metadata: dict) -> Optional[UsageMetadata]:
    """Get UsageMetadata from raw response metadata."""
    input_tokens = raw_metadata.get("prompt_token_count", 0)
    output_tokens = raw_metadata.get("candidates_token_count", 0)
    total_tokens = raw_metadata.get("total_token_count", 0)
    if all(count == 0 for count in [input_tokens, output_tokens, total_tokens]):
        return None
    else:
        return UsageMetadata(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
        )


def _get_usage_metadata_non_gemini(raw_metadata: dict) -> Optional[UsageMetadata]:
    """Get UsageMetadata from raw response metadata."""
    token_usage = raw_metadata.get("tokenMetadata", {})
    input_tokens = token_usage.get("inputTokenCount", {}).get("totalTokens", 0)
    output_tokens = token_usage.get("outputTokenCount", {}).get("totalTokens", 0)
    if input_tokens == 0 and output_tokens == 0:
        return None
    else:
        return UsageMetadata(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )


def _get_tool_name(tool: _ToolType) -> str:
    vertexai_tool = _format_to_gapic_tool([tool])
    return [f.name for f in vertexai_tool.function_declarations][0]
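# A minimal sketch (not in the original source) of how the helpers above map raw
# response metadata to LangChain's UsageMetadata; the dict literals are
# illustrative inputs only.
#
#     _get_usage_metadata_gemini(
#         {"prompt_token_count": 10, "candidates_token_count": 5, "total_token_count": 15}
#     )
#     # -> UsageMetadata(input_tokens=10, output_tokens=5, total_tokens=15)
#
#     _get_usage_metadata_non_gemini(
#         {"tokenMetadata": {"inputTokenCount": {"totalTokens": 8},
#                            "outputTokenCount": {"totalTokens": 4}}}
#     )
#     # -> UsageMetadata(input_tokens=8, output_tokens=4, total_tokens=12)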