import copy
from enum import Enum, auto
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Union,
)
import httpx
from google import auth
from google.auth.credentials import Credentials
from google.auth.transport import requests as auth_requests
from httpx_sse import (
EventSource,
aconnect_sse,
connect_sse,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
from langchain_google_vertexai._base import _VertexAIBase
_MISTRAL_MODELS: List[str] = [
"mistral-nemo@2407",
"mistral-large@2407",
]
_LLAMA_MODELS: List[str] = [
"meta/llama3-405b-instruct-maas",
"meta/llama3-70b-instruct-maas",
"meta/llama3-8b-instruct-maas",
"meta/llama-3.2-90b-vision-instruct-maas",
]
def _get_token(credentials: Optional[Credentials] = None) -> str:
    """Return a freshly refreshed access token for GCP auth.

    Args:
        credentials: Credentials to refresh. When falsy, Application Default
            Credentials are resolved through ``google.auth.default``.

    Returns:
        The token held by the credentials after the refresh.

    Raises:
        ValueError: If the credentials hold no token after being refreshed.
    """
    if not credentials:
        # Ask for the cloud-platform scope explicitly. Default credentials
        # backed by a service-account key are otherwise unscoped and their
        # refresh can be rejected; google-auth only applies the scope to
        # credential types that require one.
        credentials = auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"]
        )[0]
    request = auth_requests.Request()
    credentials.refresh(request)
    if not credentials.token:
        raise ValueError("Couldn't retrieve a token!")
    return credentials.token
def _raise_on_error(response: httpx.Response) -> None:
    """Raise an error if the response carries an error status code.

    Args:
        response: The response to inspect. Its body is read only when the
            status code is an error.

    Raises:
        httpx.HTTPStatusError: If ``httpx.codes.is_error`` reports the status
            code as an error. The message contains the status code, the URL
            and the response body.
    """
    if not httpx.codes.is_error(response.status_code):
        return
    # Decode leniently: a body that is not valid UTF-8 must not replace the
    # HTTP error with a UnicodeDecodeError.
    error_message = response.read().decode("utf-8", errors="replace")
    raise httpx.HTTPStatusError(
        f"Error response {response.status_code} "
        f"while fetching {response.url}: {error_message}",
        request=response.request,
        response=response,
    )
async def _araise_on_error(response: httpx.Response) -> None:
    """Raise an error if the response carries an error status code (async).

    Args:
        response: The response to inspect. Its body is read asynchronously,
            and only when the status code is an error.

    Raises:
        httpx.HTTPStatusError: If ``httpx.codes.is_error`` reports the status
            code as an error. The message contains the status code, the URL
            and the response body.
    """
    if not httpx.codes.is_error(response.status_code):
        return
    # Decode leniently: a body that is not valid UTF-8 must not replace the
    # HTTP error with a UnicodeDecodeError.
    raw_body = await response.aread()
    error_message = raw_body.decode("utf-8", errors="replace")
    raise httpx.HTTPStatusError(
        f"Error response {response.status_code} "
        f"while fetching {response.url}: {error_message}",
        request=response.request,
        response=response,
    )
async def _aiter_sse(
    event_source_mgr: AsyncContextManager[EventSource],
) -> AsyncIterator[Dict]:
    """Yield the JSON payload of every server-sent event up to ``[DONE]``.

    The event source is entered here, so the connection is only opened once
    iteration starts. The response status is checked before any event is
    consumed.

    Args:
        event_source_mgr: Async context manager producing the event source.

    Yields:
        The result of ``event.json()`` for each event received before an
        event whose data equals ``"[DONE]"``.
    """
    async with event_source_mgr as source:
        await _araise_on_error(source.response)
        async for sse in source.aiter_sse():
            # "[DONE]" is the end-of-stream marker and carries no payload.
            if sse.data == "[DONE]":
                break
            yield sse.json()
class VertexMaaSModelFamily(str, Enum):
    """Model families served through Vertex AI Model-as-a-Service.

    Calling the enum with a supported model name (case-insensitive) resolves
    to the matching family through ``_missing_``.
    """

    LLAMA = auto()
    # https://cloud.google.com/blog/products/ai-machine-learning/llama-3-1-on-vertex-ai
    MISTRAL = auto()
    # https://cloud.google.com/blog/products/ai-machine-learning/codestral-and-mistral-large-v2-on-vertex-ai

    @classmethod
    def _missing_(cls, value: Any) -> "VertexMaaSModelFamily":
        """Map a model name to its family.

        Args:
            value: The value the enum was called with; expected to be a model
                name string.

        Returns:
            The family whose model list contains the lower-cased name.

        Raises:
            ValueError: If ``value`` is not a string or names an unsupported
                model.
        """
        # Guard first: a missing (None) or non-string model name would
        # otherwise fail on ``.lower()`` with an AttributeError instead of
        # the ValueError an enum lookup is expected to raise.
        if not isinstance(value, str):
            raise ValueError(f"Model name must be a string, got {value!r}!")
        model_name = value.lower()
        if model_name in _LLAMA_MODELS:
            return VertexMaaSModelFamily.LLAMA
        if model_name in _MISTRAL_MODELS:
            return VertexMaaSModelFamily.MISTRAL
        raise ValueError(f"Model {model_name} is not supported yet!")
class _BaseVertexMaasModelGarden(_VertexAIBase):
    # Shared base for Model-as-a-Service models: resolves the model family,
    # builds the endpoint URLs and creates the sync/async HTTP clients.
    # (Written as a comment rather than a docstring so the class ``__doc__``
    # stays exactly as it was.)

    append_tools_to_system_message: bool = False
    "Whether to append tools to the system message or not."
    # Family resolved from ``model_name`` by the "after" validator below.
    model_family: Optional[VertexMaaSModelFamily] = None
    # Timeout value handed to both httpx clients.
    timeout: int = 120

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )

    def __init__(self, **kwargs):
        """Initialise the model, then build the sync and async HTTP clients.

        Both clients share the same base URL, headers and timeout.
        """
        super().__init__(**kwargs)
        # NOTE(review): the token is fetched once, at construction time, and
        # baked into the client headers - confirm whether long-lived
        # instances need it refreshed.
        bearer = _get_token(credentials=self.credentials)
        shared_options = {
            "base_url": self.get_url(),
            "headers": {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {bearer}",
                "x-goog-api-client": self._library_version,
                # NOTE(review): the key is "user_agent", not "User-Agent" -
                # confirm that this is intended.
                "user_agent": self._user_agent,
            },
            "timeout": self.timeout,
        }
        self.client = httpx.Client(**shared_options)
        self.async_client = httpx.AsyncClient(**shared_options)

    @model_validator(mode="after")
    def validate_environment_model_garden(self) -> Self:
        """Resolve the model family and normalise Mistral model names.

        The family is looked up from ``model_name`` and stored in
        ``model_family``. For Mistral models the original name is kept in
        ``full_model_name`` and ``model_name`` is cut down to the part before
        the first ``"@"``.
        """
        family = VertexMaaSModelFamily(self.model_name)
        self.model_family = family
        if family != VertexMaaSModelFamily.MISTRAL:
            return self
        versioned_name = self.model_name
        self.full_model_name = versioned_name
        self.model_name = versioned_name.partition("@")[0] if versioned_name else None
        return self

    def _enrich_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``params`` made compliant with Vertex AI.

        ``safe_prompt`` is dropped and ``model`` is set to ``model_name``;
        the caller's dictionary is left untouched.
        """
        enriched = copy.deepcopy(params)
        enriched.pop("safe_prompt", None)
        enriched["model"] = self.model_name
        return enriched

    def _get_url_part(self, stream: bool = False) -> str:
        """Return the request path to append to the client base URL.

        Mistral models use a ``rawPredict`` / ``streamRawPredict`` publisher
        path; every other case uses the OpenAPI chat-completions endpoint.
        """
        if self.model_family != VertexMaaSModelFamily.MISTRAL:
            return "endpoints/openapi/chat/completions"
        method = "streamRawPredict" if stream else "rawPredict"
        return f"publishers/mistralai/models/{self.full_model_name}:{method}"

    def get_url(self) -> str:
        """Return the regional base URL for this project and location.

        Llama models are addressed through ``v1beta1``, everything else
        through ``v1``.
        """
        is_llama = self.model_family == VertexMaaSModelFamily.LLAMA
        version = "v1beta1" if is_llama else "v1"
        return (
            f"https://{self.location}-aiplatform.googleapis.com/{version}/projects/"
            f"{self.project}/locations/{self.location}"
        )
def _create_retry_decorator(
    llm: _BaseVertexMaasModelGarden,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used for completion calls.

    Args:
        llm: Model whose ``max_retries`` setting is passed to the decorator.
        run_manager: Optional callback manager forwarded to the decorator.

    Returns:
        A decorator from ``create_base_retry_decorator`` configured for
        ``httpx.RequestError`` and ``httpx.StreamError``.
    """
    return create_base_retry_decorator(
        error_types=[httpx.RequestError, httpx.StreamError],
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )
async def acompletion_with_retry(
    llm: _BaseVertexMaasModelGarden,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Run an async completion call wrapped in the retry decorator.

    Args:
        llm: Model providing the async client, URL parts and retry settings.
        run_manager: Optional callback manager passed to the retry decorator.
        **kwargs: Request parameters. ``stream`` (default ``False``) selects
            streaming; ``headers_content_type`` optionally overrides the
            ``Content-Type`` header of a streaming request and is never sent
            as part of the JSON body.

    Returns:
        The decoded JSON response, or for streaming requests an async
        iterator over the decoded server-sent events.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status
            (for streaming requests, raised once iteration starts).
    """
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**payload: Any) -> Any:
        # Transport option, not a model parameter: remove it before the
        # payload is serialised, for streaming and non-streaming calls alike.
        content_type = payload.pop("headers_content_type", None)
        stream = payload.setdefault("stream", False)
        if not stream:
            response = await llm.async_client.post(
                url=llm._get_url_part(), json=payload
            )
            await _araise_on_error(response)
            return response.json()

        # Llama and Mistral expect different "Content-Type" for streaming
        headers = {"Accept": "text/event-stream"}
        if content_type:
            headers["Content-Type"] = content_type
        event_source = aconnect_sse(
            llm.async_client,
            "POST",
            llm._get_url_part(stream=True),
            json=payload,
            headers=headers,
        )
        # The connection is opened lazily, when the iterator is consumed.
        return _aiter_sse(event_source)

    return await _completion_with_retry(**llm._enrich_params(kwargs))
def completion_with_retry(
    llm: _BaseVertexMaasModelGarden,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Run a synchronous completion call, retrying non-streaming requests.

    Args:
        llm: Model providing the client, URL parts and retry settings.
        run_manager: Optional callback manager passed to the retry decorator.
        **kwargs: Request parameters. ``stream`` (default ``False``) selects
            streaming; ``headers_content_type`` optionally overrides the
            ``Content-Type`` header of a streaming request and is never sent
            as part of the JSON body.

    Returns:
        The decoded JSON response, or for streaming requests a generator
        over the decoded server-sent events.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status
            (for streaming requests, raised once iteration starts).
    """
    kwargs.setdefault("stream", False)
    stream = kwargs["stream"]
    kwargs = llm._enrich_params(kwargs)
    # Transport option, not a model parameter: remove it before the payload
    # is serialised, for streaming and non-streaming calls alike.
    headers_content_type = kwargs.pop("headers_content_type", None)

    if stream:
        # Llama and Mistral expect different "Content-Type" for streaming
        headers = {"Accept": "text/event-stream"}
        if headers_content_type:
            headers["Content-Type"] = headers_content_type

        def iter_sse():
            # The connection is opened lazily, when the generator is
            # consumed, so the streaming path is not wrapped in a retry.
            with connect_sse(
                llm.client,
                "POST",
                llm._get_url_part(stream=True),
                json=kwargs,
                headers=headers,
            ) as event_source:
                _raise_on_error(event_source.response)
                for event in event_source.iter_sse():
                    if event.data == "[DONE]":
                        return
                    yield event.json()

        return iter_sse()

    # Mirror the async variant: the request goes through the shared retry
    # decorator so the function retries as its name says.
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _post_with_retry() -> Any:
        response = llm.client.post(url=llm._get_url_part(), json=kwargs)
        _raise_on_error(response)
        return response.json()

    return _post_with_retry()