from __future__ import annotations

import re
from concurrent.futures import Executor
from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple

import vertexai  # type: ignore[import-untyped]
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform.gapic import (
    PredictionServiceAsyncClient,
    PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
    PredictionServiceAsyncClient as v1beta1PredictionServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
    PredictionServiceClient as v1beta1PredictionServiceClient,
)
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from vertexai.generative_models._generative_models import (  # type: ignore
    SafetySettingsType,
)
from vertexai.language_models import (  # type: ignore[import-untyped]
    TextGenerationModel,
)
from vertexai.preview.language_models import (  # type: ignore
    ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
    CodeChatModel as PreviewCodeChatModel,
)

from langchain_google_vertexai._utils import (
    GoogleModelFamily,
    get_client_info,
    get_user_agent,
    is_gemini_model,
)

_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
_PALM_DEFAULT_TEMPERATURE = 0.0
_PALM_DEFAULT_TOP_P = 0.95
_PALM_DEFAULT_TOP_K = 40
_DEFAULT_LOCATION = "us-central1"


class _VertexAIBase(BaseModel):
    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    project: Optional[str] = None
    "The default GCP project to use when making Vertex API calls."
    location: str = Field(default=_DEFAULT_LOCATION)
    "The default location to use when making API calls."
    request_parallelism: int = 5
    "The amount of parallelism allowed for requests issued to VertexAI models. "
    "Default is 5."
    max_retries: int = 6
    """The maximum number of retries to make when generating."""
    task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
    stop: Optional[List[str]] = Field(default=None, alias="stop_sequences")
    "Optional list of stop words to use when generating."
    model_name: Optional[str] = Field(default=None, alias="model")
    "Underlying model name."
    full_model_name: Optional[str] = None  #: :meta private:
    "The full name of the model's endpoint."
    client_options: Optional["ClientOptions"] = Field(
        default=None, exclude=True
    )  #: :meta private:
    api_endpoint: Optional[str] = Field(default=None, alias="base_url")
    "Desired API endpoint, e.g., us-central1-aiplatform.googleapis.com"
    api_transport: Optional[str] = None
    """The desired API transport method, can be either 'grpc' or 'rest'.
    Uses the default parameter in vertexai.init if defined.
    """
    default_metadata: Sequence[Tuple[str, str]] = Field(
        default_factory=list
    )  #: :meta private:
    additional_headers: Optional[Dict[str, str]] = Field(default=None)
    "A key-value dictionary representing additional headers for the model call."
    client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None
    "A callback which returns client certificate bytes and private key bytes, both "
    "in PEM format."
    credentials: Any = Field(default=None, exclude=True)
    "The default custom credentials (google.auth.credentials.Credentials) to use "
    "when making API calls. If not provided, credentials will be ascertained from "
    "the environment."

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_params_base(cls, values: dict) -> dict:
        if "model" in values and "model_name" not in values:
            values["model_name"] = values.pop("model")
        if values.get("project") is None:
            values["project"] = initializer.global_config.project
        if values.get("api_transport") is None:
            values["api_transport"] = initializer.global_config._api_transport
        if values.get("api_endpoint"):
            api_endpoint = values["api_endpoint"]
        else:
            # Fall back to the regional prediction endpoint derived from the
            # configured (or default) location.
            location = values.get("location", cls.__fields__["location"].default)
            api_endpoint = f"{location}-{constants.PREDICTION_API_BASE_PATH}"
        client_options = ClientOptions(api_endpoint=api_endpoint)
        if values.get("client_cert_source"):
            client_options.client_cert_source = values["client_cert_source"]
        values["client_options"] = client_options
        additional_headers = values.get("additional_headers", {})
        values["default_metadata"] = tuple(additional_headers.items())
        return values

    @property
    def prediction_client(self) -> v1beta1PredictionServiceClient:
        """Returns PredictionServiceClient."""
        if self.client is None:
            self.client = v1beta1PredictionServiceClient(
                credentials=self.credentials,
                client_options=self.client_options,
                client_info=get_client_info(module=self._user_agent),
                transport=self.api_transport,
            )
        return self.client

    @property
    def async_prediction_client(self) -> v1beta1PredictionServiceAsyncClient:
        """Returns PredictionServiceAsyncClient."""
        if self.async_client is None:
            async_client_kwargs: dict[str, Any] = dict(
                client_options=self.client_options,
                client_info=get_client_info(module=self._user_agent),
                credentials=self.credentials,
            )
            if self.api_transport is not None:
                async_client_kwargs["transport"] = self.api_transport
            self.async_client = v1beta1PredictionServiceAsyncClient(
                **async_client_kwargs
            )
        return self.async_client

    @property
    def _user_agent(self) -> str:
        """Gets the User Agent."""
        _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
        return user_agent

    @property
    def _library_version(self) -> str:
        """Gets the library version for headers."""
        library_version, _ = get_user_agent(f"{type(self).__name__}_{self.model_name}")
        return library_version
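# A minimal sketch (not part of the upstream module) of how the root validator
# above resolves the endpoint, assuming constants.PREDICTION_API_BASE_PATH
# resolves to "aiplatform.googleapis.com". With a location set and no explicit
# api_endpoint, the client options point at the regional prediction endpoint:
#
#     llm = _VertexAIBase(project="my-project", location="europe-west1")
#     llm.client_options.api_endpoint
#     # -> "europe-west1-aiplatform.googleapis.com"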
class _VertexAICommon(_VertexAIBase):
    client_preview: Any = None  #: :meta private:
    model_name: str = Field(default=None, alias="model")
    "Underlying model name."
    temperature: Optional[float] = None
    "Sampling temperature; it controls the degree of randomness in token selection."
    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
    "Token limit determines the maximum amount of text output from one prompt."
    top_p: Optional[float] = None
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: Optional[int] = None
    "How the model selects tokens for output; the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."
    n: int = 1
    """How many completions to generate for each prompt."""
    streaming: bool = False
    """Whether to stream the results or not."""
    model_family: Optional[GoogleModelFamily] = None  #: :meta private:
    safety_settings: Optional["SafetySettingsType"] = None
    """The default safety settings to use for all generations.

        For example:

            from langchain_google_vertexai import HarmBlockThreshold, HarmCategory

            safety_settings = {
                HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            }
    """  # noqa: E501
    tuned_model_name: Optional[str] = None
    """The name of a tuned model. If tuned_model_name is passed,
    model_name will be used to determine the model family.
    """

    @property
    def _is_gemini_model(self) -> bool:
        return is_gemini_model(self.model_family)  # type: ignore[arg-type]

    @property
    def _llm_type(self) -> str:
        return "vertexai"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Gets the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _default_params(self) -> Dict[str, Any]:
        if self.model_family == GoogleModelFamily.GEMINI:
            default_params: Dict[str, Any] = {}
        elif self.model_family == GoogleModelFamily.GEMINI_ADVANCED:
            default_params = {}
        else:
            default_params = {
                "temperature": _PALM_DEFAULT_TEMPERATURE,
                "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
                "top_p": _PALM_DEFAULT_TOP_P,
                "top_k": _PALM_DEFAULT_TOP_K,
            }
        params = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        if self.model_family != GoogleModelFamily.CODEY:
            params.update(
                {
                    "top_k": self.top_k,
                    "top_p": self.top_p,
                }
            )
        updated_params = {}
        for param_name, param_value in params.items():
            default_value = default_params.get(param_name)
            if param_value is not None or default_value is not None:
                updated_params[param_name] = (
                    param_value if param_value is not None else default_value
                )
        return updated_params

    @classmethod
    def _init_vertexai(cls, values: Dict) -> None:
        vertexai.init(
            project=values.get("project"),
            location=values.get("location"),
            credentials=values.get("credentials"),
            api_transport=values.get("api_transport"),
            api_endpoint=values.get("api_endpoint"),
            request_metadata=values.get("default_metadata"),
        )
        return None

    def _prepare_params(
        self,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict:
        stop_sequences = stop or self.stop
        params_mapping = {"n": "candidate_count"}
        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
        if stream or self.streaming:
            params.pop("candidate_count")
        return params

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        # PaLM-era chat models only expose token counting on a chat session,
        # so an empty chat is started just to count tokens.
        is_palm_chat_model = isinstance(
            self.client_preview, PreviewChatModel
        ) or isinstance(self.client_preview, PreviewCodeChatModel)
        if is_palm_chat_model:
            result = self.client_preview.start_chat().count_tokens(text)
        else:
            result = self.client_preview.count_tokens([text])
        return result.total_tokens
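# A minimal sketch (not part of the upstream module) of how _prepare_params
# merges defaults with call-time kwargs. Assume a hypothetical Gemini-family
# instance `llm` with temperature=0.2 and n=1. The "n" kwarg is renamed to
# "candidate_count", and call-time values win over instance defaults:
#
#     params = llm._prepare_params(stop=["\n\n"], n=2)
#     # -> {"temperature": 0.2, "candidate_count": 2, "stop_sequences": ["\n\n"]}
#
# With stream=True (or llm.streaming set), "candidate_count" is dropped from
# the result, since streaming responses carry a single candidate.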
"""is_palm_chat_model=isinstance(self.client_preview,PreviewChatModel)orisinstance(self.client_preview,PreviewCodeChatModel)ifis_palm_chat_model:result=self.client_preview.start_chat().count_tokens(text)else:result=self.client_preview.count_tokens([text])returnresult.total_tokensclass_BaseVertexAIModelGarden(_VertexAIBase):"""Large language models served from Vertex AI Model Garden."""async_client:Any=None#: :meta private:endpoint_id:str"A name of an endpoint where the model has been deployed."allowed_model_args:Optional[List[str]]=None"Allowed optional args to be passed to the model."prompt_arg:str="prompt"result_arg:Optional[str]="generated_text""Set result_arg to None if output of the model is expected to be a string.""Otherwise, if it's a dict, provided an argument that contains the result."single_example_per_request:bool=True"LLM endpoint currently serves only the first example in the request"@root_validator(pre=False,skip_on_failure=True)defvalidate_environment(cls,values:Dict)->Dict:"""Validate that the python package exists in environment."""ifnotvalues["project"]:raiseValueError("A GCP project should be provided to run inference on Model Garden!")client_options=ClientOptions(api_endpoint=f"{values['location']}-aiplatform.googleapis.com")client_info=get_client_info(module="vertex-ai-model-garden")values["client"]=PredictionServiceClient(client_options=client_options,client_info=client_info)values["async_client"]=PredictionServiceAsyncClient(client_options=client_options,client_info=client_info)returnvalues@propertydefendpoint_path(self)->str:returnself.client.endpoint_path(project=self.project,location=self.location,endpoint=self.endpoint_id)@propertydef_llm_type(self)->str:return"vertexai_model_garden"def_prepare_request(self,prompts:List[str],**kwargs:Any)->List["Value"]:instances=[]forpromptinprompts:ifself.allowed_model_args:instance={k:vfork,vinkwargs.items()ifkinself.allowed_model_args}else:instance={}instance[self.prompt_arg]=promptinstances.append(instance)predict_instances=[json_format.ParseDict(instance_dict,Value())forinstance_dictininstances]returnpredict_instancesdef_parse_response(self,predictions:"Prediction")->LLMResult:generations:List[List[Generation]]=[]forresultinpredictions.predictions:ifisinstance(result,str):generations.append([Generation(text=self._parse_prediction(result))])else:generations.append([Generation(text=self._parse_prediction(prediction))forpredictioninresult])returnLLMResult(generations=generations)def_parse_prediction(self,prediction:Any)->str:def_clean_response(response:str)->str:ifresponse.startswith("Prompt:\n"):result=re.search(r"(?s:.*)\nOutput:\n((?s:.*))",response)ifresult:returnresult[1]returnresponseifisinstance(prediction,str):return_clean_response(prediction)ifself.result_arg:try:return_clean_response(prediction[self.result_arg])exceptKeyError:ifisinstance(prediction,str):error_desc=("Provided non-None `result_arg` (result_arg="f"{self.result_arg}). But got prediction of type "f"{type(prediction)} instead of dict. Most probably, you""need to set `result_arg=None` during VertexAIModelGarden ""initialization.")raiseValueError(error_desc)else:raiseValueError(f"{self.result_arg} key not found in prediction!")returnprediction