class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
    """Large language models served from Vertex AI Model Garden."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    # Needed so that mypy doesn't flag missing aliased init args.
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)

        if self.single_example_per_request and len(instances) > 1:
            results = []
            for instance in instances:
                response = self.client.predict(
                    endpoint=self.endpoint_path, instances=[instance]
                )
                results.append(self._parse_prediction(response.predictions[0]))
            return LLMResult(
                generations=[[Generation(text=result)] for result in results]
            )

        response = self.client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        if self.single_example_per_request and len(instances) > 1:
            responses = []
            for instance in instances:
                responses.append(
                    self.async_client.predict(
                        endpoint=self.endpoint_path, instances=[instance]
                    )
                )
            responses = await asyncio.gather(*responses)
            return LLMResult(
                generations=[
                    [Generation(text=self._parse_prediction(response.predictions[0]))]
                    for response in responses
                ]
            )
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)
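
# --- Example usage (an illustrative sketch, not part of the library source).
# Assumes a model already deployed to a Vertex AI Model Garden endpoint; the
# project id and endpoint id below are hypothetical placeholders.
from langchain_google_vertexai import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-project",  # hypothetical GCP project id
    endpoint_id="1234567890",  # hypothetical Model Garden endpoint id
    location="us-central1",
)
# Single prompt; with single_example_per_request=True, batched prompts are
# sent as one predict() call per instance, as in _generate above.
print(llm.invoke("What is the meaning of life?"))
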
class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
    async_client: Any = None  #: :meta private:
    model_name: Optional[str] = Field(default=None, alias="model")  # type: ignore[assignment]
    "Underlying model name."
    max_output_tokens: int = Field(default=1024, alias="max_tokens")
    access_token: Optional[str] = None
    stream_usage: bool = True  # Whether to include usage metadata in streaming output
    credentials: Optional[Credentials] = None

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    # Needed so that mypy doesn't flag missing aliased init args.
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        from anthropic import (  # type: ignore
            AnthropicVertex,
            AsyncAnthropicVertex,
        )

        values["client"] = AnthropicVertex(
            project_id=values["project"],
            region=values["location"],
            max_retries=values["max_retries"],
            access_token=values["access_token"],
            credentials=values["credentials"],
        )
        values["async_client"] = AsyncAnthropicVertex(
            project_id=values["project"],
            region=values["location"],
            max_retries=values["max_retries"],
            access_token=values["access_token"],
            credentials=values["credentials"],
        )
        return values

    @property
    def _default_params(self):
        return {
            "model": self.model_name,
            "max_tokens": self.max_output_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        system_message, formatted_messages = _format_messages_anthropic(messages)
        params = self._default_params
        params.update(kwargs)
        if kwargs.get("model_name"):
            params["model"] = params["model_name"]
        if kwargs.get("model"):
            params["model"] = kwargs["model"]
        params.pop("model_name", None)
        params.update(
            {
                "system": system_message,
                "messages": formatted_messages,
                "stop_sequences": stop,
            }
        )
        return {k: v for k, v in params.items() if v is not None}

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        data_dict = data.model_dump()
        content = [c for c in data_dict["content"] if c["type"] != "tool_use"]
        content = content[0]["text"] if len(content) == 1 else content
        llm_output = {
            k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
        }
        tool_calls = _extract_tool_calls(data_dict["content"])
        if tool_calls:
            msg = AIMessage(content=content, tool_calls=tool_calls)
        else:
            msg = AIMessage(content=content)
        # Collect token usage
        msg.usage_metadata = {
            "input_tokens": data.usage.input_tokens,
            "output_tokens": data.usage.output_tokens,
            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
        }
        return ChatResult(
            generations=[ChatGeneration(message=msg)],
            llm_output=llm_output,
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        data = self.client.messages.create(**params)
        return self._format_output(data, **kwargs)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        data = await self.async_client.messages.create(**params)
        return self._format_output(data, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat-vertexai"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        *,
        stream_usage: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        if stream_usage is None:
            stream_usage = self.stream_usage
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        stream = self.client.messages.create(**params, stream=True)
        coerce_content_to_string = not _tools_in_params(params)
        for event in stream:
            msg = _make_message_chunk_from_anthropic_event(
                event,
                stream_usage=stream_usage,
                coerce_content_to_string=coerce_content_to_string,
            )
            if msg is not None:
                chunk = ChatGenerationChunk(message=msg)
                if run_manager and isinstance(msg.content, str):
                    run_manager.on_llm_new_token(msg.content, chunk=chunk)
                yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        *,
        stream_usage: Optional[bool] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        if stream_usage is None:
            stream_usage = self.stream_usage
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        stream = await self.async_client.messages.create(**params, stream=True)
        coerce_content_to_string = not _tools_in_params(params)
        async for event in stream:
            msg = _make_message_chunk_from_anthropic_event(
                event,
                stream_usage=stream_usage,
                coerce_content_to_string=coerce_content_to_string,
            )
            if msg is not None:
                chunk = ChatGenerationChunk(message=msg)
                if run_manager and isinstance(msg.content, str):
                    await run_manager.on_llm_new_token(msg.content, chunk=chunk)
                yield chunk
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[Dict[str, str], Literal["any", "auto"], str]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model."""
        formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
        if not tool_choice:
            pass
        elif isinstance(tool_choice, dict):
            kwargs["tool_choice"] = tool_choice
        elif isinstance(tool_choice, str) and tool_choice in ("any", "auto"):
            kwargs["tool_choice"] = {"type": tool_choice}
        elif isinstance(tool_choice, str):
            kwargs["tool_choice"] = {"type": "tool", "name": tool_choice}
        else:
            raise ValueError(
                f"Unrecognized 'tool_choice' type {tool_choice=}. Expected dict, "
                f"str, or None."
            )
        return self.bind(tools=formatted_tools, **kwargs)
    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema."""
        tool_name = convert_to_anthropic_tool(schema)["name"]
        llm = self.bind_tools([schema], tool_choice=tool_name)
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            output_parser = ToolsOutputParser(
                first_tool_only=True, pydantic_schemas=[schema]
            )
        else:
            output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser,
                parsing_error=lambda _: None,
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser
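
# --- Example: basic chat usage (an illustrative sketch, not part of the
# library source). The model name, project, and location are hypothetical
# placeholders; ChatAnthropicVertex also needs the ``anthropic`` package
# installed, since validate_environment builds AnthropicVertex clients from it.
from langchain_core.messages import HumanMessage
from langchain_google_vertexai.model_garden import ChatAnthropicVertex

chat = ChatAnthropicVertex(
    model_name="claude-3-sonnet@20240229",  # hypothetical model id
    project="my-project",  # hypothetical GCP project id
    location="us-east5",
)
response = chat.invoke([HumanMessage(content="Hello!")])
print(response.content)
print(response.usage_metadata)  # token usage collected in _format_output

# Streaming goes through _stream, which emits one chunk per Anthropic event:
for chunk in chat.stream("Tell me a joke about bears."):
    print(chunk.content, end="", flush=True)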
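
# --- Example: binding tools (an illustrative sketch). ``GetWeather`` is a
# hypothetical schema; bind_tools also accepts dicts, callables, and BaseTools.
# Assumption: user schemas use the same pydantic flavor as this module
# (langchain_core.pydantic_v1).
from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get the current weather for a location."""

    location: str = Field(..., description="City name, e.g. 'Paris'")

# A bare string tool_choice becomes {"type": "tool", "name": "GetWeather"},
# forcing the model to call that tool; "any" or "auto" are also accepted.
llm_with_tools = chat.bind_tools([GetWeather], tool_choice="GetWeather")
msg = llm_with_tools.invoke("What's the weather in Paris?")
print(msg.tool_calls)  # tool calls extracted from the Anthropic response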
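
# --- Example: structured output (an illustrative sketch, reusing the
# hypothetical ``chat`` instance from above). ``Person`` is a placeholder
# schema; a pydantic class yields parsed model instances, a dict schema
# yields plain dicts.
class Person(BaseModel):
    """Information about a person."""

    name: str
    age: int

structured_llm = chat.with_structured_output(Person)
person = structured_llm.invoke("Anna is 28 years old.")
print(person)  # a Person instance parsed from the forced tool call

# With include_raw=True the chain returns the raw AIMessage alongside the
# parsed object, recording any parsing failure instead of raising:
structured_with_raw = chat.with_structured_output(Person, include_raw=True)
result = structured_with_raw.invoke("Anna is 28 years old.")
print(result["raw"], result["parsed"], result["parsing_error"])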