# Source code for langchain_google_vertexai.model_garden_maas._base
importcopyfromenumimportEnum,autofromtypingimport(Any,AsyncContextManager,AsyncIterator,Callable,Dict,List,Optional,Union,)importhttpxfromgoogleimportauthfromgoogle.auth.credentialsimportCredentialsfromgoogle.auth.transportimportrequestsasauth_requestsfromhttpx_sseimport(EventSource,aconnect_sse,connect_sse,)fromlangchain_core.callbacksimport(AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun,)fromlangchain_core.language_models.llmsimportcreate_base_retry_decoratorfrompydanticimportConfigDict,model_validatorfromtyping_extensionsimportSelffromlangchain_google_vertexai._baseimport_VertexAIBase_MISTRAL_MODELS:List[str]=["mistral-nemo@2407","mistral-large@2407","mistral-large-2411@001","mistral-small-2503@001","codestral-2501@001",]_LLAMA_MODELS:List[str]=["meta/llama3-405b-instruct-maas","meta/llama3-70b-instruct-maas","meta/llama3-8b-instruct-maas","meta/llama-3.2-90b-vision-instruct-maas","meta/llama-3.3-70b-instruct-maas",]def_get_token(credentials:Optional[Credentials]=None)->str:"""Returns a valid token for GCP auth."""credentials=(auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])[0]ifnotcredentialselsecredentials)request=auth_requests.Request()credentials.refresh(request)ifnotcredentials.token:raiseValueError("Couldn't retrieve a token!")returncredentials.tokendef_raise_on_error(response:httpx.Response)->None:"""Raise an error if the response is an error."""ifhttpx.codes.is_error(response.status_code):error_message=response.read().decode("utf-8")raisehttpx.HTTPStatusError(f"Error response {response.status_code} "f"while fetching {response.url}: {error_message}",request=response.request,response=response,)asyncdef_araise_on_error(response:httpx.Response)->None:"""Raise an error if the response is an error."""ifhttpx.codes.is_error(response.status_code):error_message=(awaitresponse.aread()).decode("utf-8")raisehttpx.HTTPStatusError(f"Error response {response.status_code} "f"while fetching {response.url}: 
{error_message}",request=response.request,response=response,)asyncdef_aiter_sse(event_source_mgr:AsyncContextManager[EventSource],)->AsyncIterator[Dict]:"""Iterate over the server-sent events."""asyncwithevent_source_mgrasevent_source:await_araise_on_error(event_source.response)asyncforeventinevent_source.aiter_sse():ifevent.data=="[DONE]":returnyieldevent.json()classVertexMaaSModelFamily(str,Enum):LLAMA=auto()# https://cloud.google.com/blog/products/ai-machine-learning/llama-3-1-on-vertex-aiMISTRAL=auto()# https://cloud.google.com/blog/products/ai-machine-learning/codestral-and-mistral-large-v2-on-vertex-ai@classmethoddef_missing_(cls,value:Any)->"VertexMaaSModelFamily":model_name=value.lower()ifmodel_namein_LLAMA_MODELS:returnVertexMaaSModelFamily.LLAMAifmodel_namein_MISTRAL_MODELS:returnVertexMaaSModelFamily.MISTRALraiseValueError(f"Model {model_name} is not supported yet!")class_BaseVertexMaasModelGarden(_VertexAIBase):append_tools_to_system_message:bool=False"Whether to append tools to the system message or not."model_family:Optional[VertexMaaSModelFamily]=Nonetimeout:int=120model_config=ConfigDict(populate_by_name=True,arbitrary_types_allowed=True,)def__init__(self,**kwargs):super().__init__(**kwargs)token=_get_token(credentials=self.credentials)endpoint=self.get_url()headers={"Content-Type":"application/json","Accept":"application/json","Authorization":f"Bearer {token}","x-goog-api-client":self._library_version,"user_agent":self._user_agent,}self.client=httpx.Client(base_url=endpoint,headers=headers,timeout=self.timeout,)self.async_client=httpx.AsyncClient(base_url=endpoint,headers=headers,timeout=self.timeout,)@model_validator(mode="after")defvalidate_environment_model_garden(self)->Self:"""Validate that the python package exists in 
environment."""family=VertexMaaSModelFamily(self.model_name)self.model_family=familyiffamily==VertexMaaSModelFamily.MISTRAL:model=self.model_name.split("@")[0]ifself.model_nameelseNoneself.full_model_name=self.model_nameself.model_name=modelreturnselfdef_enrich_params(self,params:Dict[str,Any])->Dict[str,Any]:"""Fix params to be compliant with Vertex AI."""copy_params=copy.deepcopy(params)_=copy_params.pop("safe_prompt",None)copy_params["model"]=self.model_namereturncopy_paramsdef_get_url_part(self,stream:bool=False)->str:ifself.model_family==VertexMaaSModelFamily.MISTRAL:ifstream:return(f"publishers/mistralai/models/{self.full_model_name}"":streamRawPredict")returnf"publishers/mistralai/models/{self.full_model_name}:rawPredict"return"endpoints/openapi/chat/completions"defget_url(self)->str:ifself.model_family==VertexMaaSModelFamily.LLAMA:version="v1beta1"else:version="v1"return(f"https://{self.location}-aiplatform.googleapis.com/{version}/projects/"f"{self.project}/locations/{self.location}")def_create_retry_decorator(llm:_BaseVertexMaasModelGarden,run_manager:Optional[Union[AsyncCallbackManagerForLLMRun,CallbackManagerForLLMRun]]=None,)->Callable[[Any],Any]:"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""errors=[httpx.RequestError,httpx.StreamError]returncreate_base_retry_decorator(error_types=errors,max_retries=llm.max_retries,run_manager=run_manager)
async def acompletion_with_retry(
    llm: _BaseVertexMaasModelGarden,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call.

    Args:
        llm: The model-garden model providing the async httpx client,
            retry configuration and URL helpers.
        run_manager: Optional async callback manager forwarded to the
            retry decorator.
        **kwargs: Request payload. ``stream`` selects SSE streaming;
            ``headers_content_type`` (popped, not sent) overrides the
            streaming Content-Type header.

    Returns:
        The parsed JSON response dict, or an async iterator of SSE event
        dicts when ``stream`` is true.
    """
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        if "stream" not in kwargs:
            kwargs["stream"] = False
        stream = kwargs["stream"]
        if stream:
            # Llama and Mistral expect different "Content-Type" for streaming
            headers = {"Accept": "text/event-stream"}
            if headers_content_type := kwargs.pop("headers_content_type", None):
                headers["Content-Type"] = headers_content_type
            event_source = aconnect_sse(
                llm.async_client,
                "POST",
                llm._get_url_part(stream=True),
                json=kwargs,
                headers=headers,
            )
            # NOTE(review): the connection happens lazily inside _aiter_sse,
            # so retries only cover creating the context manager here.
            return _aiter_sse(event_source)
        else:
            response = await llm.async_client.post(
                url=llm._get_url_part(), json=kwargs
            )
            await _araise_on_error(response)
            return response.json()

    kwargs = llm._enrich_params(kwargs)
    return await _completion_with_retry(**kwargs)
def completion_with_retry(llm: _BaseVertexMaasModelGarden, **kwargs: Any) -> Any:
    """Make a (sync) completion call, streaming or not.

    Args:
        llm: The model-garden model providing the sync httpx client and
            URL helpers.
        **kwargs: Request payload. ``stream`` selects SSE streaming;
            ``headers_content_type`` (popped, not sent) overrides the
            streaming Content-Type header.

    Returns:
        The parsed JSON response dict, or a generator of SSE event dicts
        when ``stream`` is true.
    """
    if "stream" not in kwargs:
        kwargs["stream"] = False
    stream = kwargs["stream"]
    kwargs = llm._enrich_params(kwargs)
    if stream:
        # Llama and Mistral expect different "Content-Type" for streaming
        headers = {"Accept": "text/event-stream"}
        if headers_content_type := kwargs.pop("headers_content_type", None):
            headers["Content-Type"] = headers_content_type

        def iter_sse():
            # Connect lazily when the caller starts iterating; stop on the
            # OpenAI-style "[DONE]" sentinel.
            with connect_sse(
                llm.client,
                "POST",
                llm._get_url_part(stream=True),
                json=kwargs,
                headers=headers,
            ) as event_source:
                _raise_on_error(event_source.response)
                for event in event_source.iter_sse():
                    if event.data == "[DONE]":
                        return
                    yield event.json()

        return iter_sse()
    response = llm.client.post(url=llm._get_url_part(), json=kwargs)
    _raise_on_error(response)
    return response.json()