"""Chat Model Components Derived from ChatModel/NVIDIA"""
from __future__ import annotations
import base64
import enum
import io
import logging
import os
import sys
import urllib.parse
import warnings
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
)
import requests
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.output_parsers import (
BaseOutputParser,
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
Generation,
)
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
from langchain_nvidia_ai_endpoints._utils import convert_message_to_dict
_CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
_DictOrPydanticOrEnumClass = Union[Dict[str, Any], Type[BaseModel], Type[enum.Enum]]
_DictOrPydanticOrEnum = Union[Dict, BaseModel, enum.Enum]
try:
import PIL.Image
has_pillow = True
except ImportError:
has_pillow = False
logger = logging.getLogger(__name__)
def _is_url(s: str) -> bool:
try:
result = urllib.parse.urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _resize_image(img_data: bytes, max_dim: int = 1024) -> str:
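    """Downscale image bytes so the longest side is at most ``max_dim`` pixels
    and return the result as a base64-encoded JPEG string. If Pillow is not
    installed, the original bytes are base64-encoded unchanged."""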
if not has_pillow:
print( # noqa: T201
"Pillow is required to resize images down to reasonable scale."
" Please install it using `pip install pillow`."
" For now, not resizing; may cause NVIDIA API to fail."
)
return base64.b64encode(img_data).decode("utf-8")
image = PIL.Image.open(io.BytesIO(img_data))
max_dim_size = max(image.size)
aspect_ratio = max_dim / max_dim_size
new_h = int(image.size[1] * aspect_ratio)
new_w = int(image.size[0] * aspect_ratio)
resized_image = image.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
output_buffer = io.BytesIO()
resized_image.save(output_buffer, format="JPEG")
output_buffer.seek(0)
resized_b64_string = base64.b64encode(output_buffer.read()).decode("utf-8")
return resized_b64_string
def _url_to_b64_string(image_source: str) -> str:
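    """Convert an image source (a URL, an existing ``data:image`` URL, or a
    local file path) into a base64 data URL accepted by the NVIDIA API,
    resizing large downloads to stay within the API's payload limit."""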
b64_template = "data:image/png;base64,{b64_string}"
try:
if _is_url(image_source):
response = requests.get(
image_source, headers={"User-Agent": "langchain-nvidia-ai-endpoints"}
)
response.raise_for_status()
encoded = base64.b64encode(response.content).decode("utf-8")
if sys.getsizeof(encoded) > 200000:
## (VK) Temporary fix. NVIDIA API has a limit of 250KB for the input.
encoded = _resize_image(response.content)
return b64_template.format(b64_string=encoded)
elif image_source.startswith("data:image"):
return image_source
elif os.path.exists(image_source):
with open(image_source, "rb") as f:
encoded = base64.b64encode(f.read()).decode("utf-8")
return b64_template.format(b64_string=encoded)
else:
raise ValueError(
"The provided string is not a valid URL, base64, or file path."
)
except Exception as e:
raise ValueError(f"Unable to process the provided image source: {e}")
def _nv_vlm_adjust_input(message_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
    The NVIDIA VLM API expects message.content of the form:
{
"role": "user",
"content": [
...,
{
"type": "image_url",
"image_url": "{data}"
},
...
]
}
    whereas the OpenAI VLM API expects message.content of the form:
{
"role": "user",
"content": [
...,
{
"type": "image_url",
"image_url": {
"url": "{url | data}"
}
},
...
]
}
    This function converts an OpenAI VLM API input message into an
    NVIDIA VLM API input message, in place. In the process, it accepts
    a URL or file path and converts it to a data URL.
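
    Example (illustrative; the image URL below is a placeholder and
    fetching it requires network access):

        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},
                },
            ],
        }
        _nv_vlm_adjust_input(message)
        # message["content"][1]["image_url"] is now a
        # "data:image/png;base64,..." data URL string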
"""
if content := message_dict.get("content"):
if isinstance(content, list):
for part in content:
if isinstance(part, dict) and "image_url" in part:
if (
isinstance(part["image_url"], dict)
and "url" in part["image_url"]
):
part["image_url"] = _url_to_b64_string(part["image_url"]["url"])
return message_dict
class ChatNVIDIA(BaseChatModel):
"""NVIDIA chat model.
Example:
.. code-block:: python
from langchain_nvidia_ai_endpoints import ChatNVIDIA
model = ChatNVIDIA(model="meta/llama2-70b")
response = model.invoke("Hello")
"""
_client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
_default_model_name: str = "meta/llama3-8b-instruct"
_default_base_url: str = "https://integrate.api.nvidia.com/v1"
base_url: str = Field(
description="Base url for model listing an invocation",
)
model: Optional[str] = Field(description="Name of the model to invoke")
temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
max_tokens: Optional[int] = Field(
1024, description="Maximum # of tokens to generate"
)
top_p: Optional[float] = Field(description="Top-p for distribution sampling")
seed: Optional[int] = Field(description="The seed for deterministic results")
stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
_base_url_var = "NVIDIA_BASE_URL"
@root_validator(pre=True)
def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
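        """Resolve ``base_url``: an explicit ``nvidia_base_url`` or ``base_url``
        argument takes precedence, then the ``NVIDIA_BASE_URL`` environment
        variable, then the default hosted endpoint."""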
values["base_url"] = (
values.get(cls._base_url_var.lower())
or values.get("base_url")
or os.getenv(cls._base_url_var)
or cls._default_base_url
)
return values
def __init__(self, **kwargs: Any):
"""
        Create a new ChatNVIDIA chat model.
        This class provides access to an NVIDIA NIM for chat. By default, it
connects to a hosted NIM, but can be configured to connect to a local NIM
using the `base_url` parameter. An API key is required to connect to the
hosted NIM.
Args:
model (str): The model to use for chat.
nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
api_key (str): Alternative to nvidia_api_key.
base_url (str): The base URL of the NIM to connect to.
Format for base URL is http://host:port
temperature (float): Sampling temperature in [0, 1].
max_tokens (int): Maximum number of tokens to generate.
top_p (float): Top-p for distribution sampling.
seed (int): A seed for deterministic results.
stop (list[str]): A list of cased stop words.
API Key:
- The recommended way to provide the API key is through the `NVIDIA_API_KEY`
environment variable.
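
          For example (sketch; the key shown is a placeholder):

            import os
            os.environ["NVIDIA_API_KEY"] = "nvapi-..."
            llm = ChatNVIDIA(model="meta/llama3-8b-instruct")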
Base URL:
        - Connect to a self-hosted model with NVIDIA NIM using the `base_url` arg
          to point to a locally running NIM, e.g. at localhost:8000:
llm = ChatNVIDIA(
base_url="http://localhost:8000/v1",
model="meta-llama3-8b-instruct"
)
"""
super().__init__(**kwargs)
self._client = _NVIDIAClient(
base_url=self.base_url,
model_name=self.model,
default_hosted_model_name=self._default_model_name,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path="{base_url}/chat/completions",
cls=self.__class__.__name__,
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
self.model = self._client.model_name
@property
def available_models(self) -> List[Model]:
"""
Get a list of available models that work with ChatNVIDIA.
"""
return self._client.get_available_models(self.__class__.__name__)
    @classmethod
def get_available_models(
cls,
**kwargs: Any,
) -> List[Model]:
"""
Get a list of available models that work with ChatNVIDIA.
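
        Example (sketch; listing hosted models requires an NVIDIA API key):

            for model in ChatNVIDIA.get_available_models():
                print(model.id)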
"""
return cls(**kwargs).available_models
@property
def _llm_type(self) -> str:
"""Return type of NVIDIA AI Foundation Model Interface."""
return "chat-nvidia-ai-playground"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
inputs = [
_nv_vlm_adjust_input(message)
for message in [convert_message_to_dict(message) for message in messages]
]
payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
response = self._client.get_req(payload=payload)
responses, _ = self._client.postprocess(response)
self._set_callback_out(responses, run_manager)
parsed_response = self._custom_postprocess(responses, streaming=False)
# for pre 0.2 compatibility w/ ChatMessage
# ChatMessage had a role property that was not present in AIMessage
parsed_response.update({"role": "assistant"})
generation = ChatGeneration(message=AIMessage(**parsed_response))
return ChatResult(generations=[generation], llm_output=responses)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[Sequence[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Allows streaming to model!"""
inputs = [
_nv_vlm_adjust_input(message)
for message in [convert_message_to_dict(message) for message in messages]
]
payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs)
for response in self._client.get_req_stream(payload=payload):
self._set_callback_out(response, run_manager)
parsed_response = self._custom_postprocess(response, streaming=True)
# for pre 0.2 compatibility w/ ChatMessageChunk
# ChatMessageChunk had a role property that was not
# present in AIMessageChunk
            # unfortunately, AIMessageChunk does not have an extensible property
# parsed_response.update({"role": "assistant"})
message = AIMessageChunk(**parsed_response)
chunk = ChatGenerationChunk(message=message)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
def _set_callback_out(
self,
result: dict,
run_manager: Optional[_CallbackManager],
) -> None:
result.update({"model_name": self.model})
if run_manager:
for cb in run_manager.handlers:
if hasattr(cb, "llm_output"):
cb.llm_output = result
def _custom_postprocess(
self, msg: dict, streaming: bool = False
) -> dict: # todo: remove
kw_left = msg.copy()
out_dict = {
"role": kw_left.pop("role", "assistant") or "assistant",
"name": kw_left.pop("name", None),
"id": kw_left.pop("id", None),
"content": kw_left.pop("content", "") or "",
"additional_kwargs": {},
"response_metadata": {},
}
# "tool_calls" is set for invoke and stream responses
if tool_calls := kw_left.pop("tool_calls", None):
assert isinstance(
tool_calls, list
), "invalid response from server: tool_calls must be a list"
# todo: break this into post-processing for invoke and stream
if not streaming:
out_dict["additional_kwargs"]["tool_calls"] = tool_calls
elif streaming:
out_dict["tool_call_chunks"] = []
for tool_call in tool_calls:
# todo: the nim api does not return the function index
# for tool calls in stream responses. this is
# an issue that needs to be resolved server-side.
# the only reason we can skip this for now
# is because the nim endpoint returns only full
# tool calls, no deltas.
# assert "index" in tool_call, (
# "invalid response from server: "
# "tool_call must have an 'index' key"
# )
assert "function" in tool_call, (
"invalid response from server: "
"tool_call must have a 'function' key"
)
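                    # expected tool_call shape from the server (illustrative):
                    #   {"index": 0, "id": "call_abc123", "type": "function",
                    #    "function": {"name": "get_weather",
                    #                 "arguments": '{"city": "Paris"}'}}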
out_dict["tool_call_chunks"].append(
{
"index": tool_call.get("index", None),
"id": tool_call.get("id", None),
"name": tool_call["function"].get("name", None),
"args": tool_call["function"].get("arguments", None),
}
)
# we only create the response_metadata from the last message in a stream.
        # if we did it for every message, we would end up with values like
        # "model_name" = "model-xyz" repeated once per streamed message.
if "finish_reason" in kw_left:
out_dict["response_metadata"] = kw_left
return out_dict
######################################################################################
## Core client-side interfaces
def _get_payload(
self, inputs: Sequence[Dict], **kwargs: Any
) -> dict: # todo: remove
"""Generates payload for the _NVIDIAClient API to send to service."""
messages: List[Dict[str, Any]] = []
for msg in inputs:
if isinstance(msg, str):
                # (WFH) this shouldn't ever be reached, but leaving this here because
# it's a Chesterton's fence I'm unwilling to touch
messages.append(dict(role="user", content=msg))
elif isinstance(msg, dict):
if msg.get("content", None) is None:
# content=None is valid for assistant messages (tool calling)
if not msg.get("role") == "assistant":
raise ValueError(f"Message {msg} has no content.")
messages.append(msg)
else:
raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
# special handling for "stop" because it always comes in kwargs.
# if user provided "stop" to invoke/stream, it will be non-None
# in kwargs.
# note: we cannot tell if the user specified stop=None to invoke/stream because
# the default value of stop is None.
# todo: remove self.stop
assert "stop" in kwargs, '"stop" param is expected in kwargs'
if kwargs["stop"] is None:
kwargs.pop("stop")
# setup default payload values
payload: Dict[str, Any] = {
"model": self.model,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"seed": self.seed,
"stop": self.stop,
}
# merge incoming kwargs with attr_kwargs giving preference to
# the incoming kwargs
payload.update(kwargs)
# remove keys with None values from payload
payload = {k: v for k, v in payload.items() if v is not None}
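        # final payload returned below, e.g. (illustrative):
        #   {"messages": [{"role": "user", "content": "hi"}],
        #    "model": "meta/llama3-8b-instruct", "max_tokens": 1024,
        #    "stream": False}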
return {"messages": messages, **payload}
    def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
function_call: Optional[str] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError("Not implemented, use `bind_tools` instead.")
# we have an Enum extension to BaseChatModel.with_structured_output and
# as a result need to type ignore for the schema parameter and return type.
    def with_structured_output( # type: ignore
self,
schema: _DictOrPydanticOrEnumClass,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydanticOrEnum]:
"""
Bind a structured output schema to the model.
The schema can be -
0. a dictionary representing a JSON schema
1. a Pydantic object
2. an Enum
0. If a dictionary is provided, the model will return a dictionary. Example:
```
json_schema = {
"title": "joke",
"description": "Joke to tell user.",
"type": "object",
"properties": {
"setup": {
"type": "string",
"description": "The setup of the joke",
},
"punchline": {
"type": "string",
"description": "The punchline to the joke",
},
},
"required": ["setup", "punchline"],
}
structured_llm = llm.with_structured_output(json_schema)
structured_llm.invoke("Tell me a joke about NVIDIA")
# Output: {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.',
# 'punchline': 'It took a big bite out of their main board.'}
```
1. If a Pydantic schema is provided, the model will return a Pydantic object.
Example:
```
from langchain_core.pydantic_v1 import BaseModel, Field
class Joke(BaseModel):
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about NVIDIA")
# Output: Joke(setup='Why did NVIDIA go broke? The hardware ate all the software.',
# punchline='It took a big bite out of their main board.')
```
2. If an Enum is provided, all values must be strings, and the model will return
an Enum object. Example:
```
import enum
class Choices(enum.Enum):
A = "A"
B = "B"
C = "C"
structured_llm = llm.with_structured_output(Choices)
structured_llm.invoke("What is the first letter in this list? [X, Y, Z, C]")
# Output: <Choices.C: 'C'>
```
Note about streaming: Unlike other streaming responses, the streamed chunks
will be increasingly complete. They will not be deltas. The last chunk will
contain the complete response.
For instance with a dictionary schema, the chunks will be:
```
structured_llm = llm.with_structured_output(json_schema)
for chunk in structured_llm.stream("Tell me a joke about NVIDIA"):
print(chunk)
# Output:
# {}
# {'setup': ''}
# {'setup': 'Why'}
# {'setup': 'Why did'}
# {'setup': 'Why did N'}
# {'setup': 'Why did NVID'}
# ...
# {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', 'punchline': 'It took a big bite out of their main board'}
# {'setup': 'Why did NVIDIA go broke? The hardware ate all the software.', 'punchline': 'It took a big bite out of their main board.'}
```
        For instance with a Pydantic schema, the chunks will be:
```
structured_llm = llm.with_structured_output(Joke)
for chunk in structured_llm.stream("Tell me a joke about NVIDIA"):
print(chunk)
# Output:
# setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline=''
# setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It'
# setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took'
# ...
# setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took a big bite out of their main board'
# setup='Why did NVIDIA go broke? The hardware ate all the software.' punchline='It took a big bite out of their main board.'
```
For Pydantic schema and Enum, the output will be None if the response is
insufficient to construct the object or otherwise invalid. For instance,
```
llm = ChatNVIDIA(max_tokens=1)
structured_llm = llm.with_structured_output(Joke)
print(structured_llm.invoke("Tell me a joke about NVIDIA"))
# Output: None
```
For more, see https://python.langchain.com/v0.2/docs/how_to/structured_output/
""" # noqa: E501
if "method" in kwargs:
warnings.warn(
"The 'method' parameter is unnecessary and is ignored. "
"The appropriate method will be chosen automatically depending "
"on the type of schema provided."
)
if include_raw:
raise NotImplementedError(
"include_raw=True is not implemented, consider "
"https://python.langchain.com/v0.2/docs/how_to/"
"structured_output/#prompting-and-parsing-model"
"-outputs-directly or rely on the structured response "
"being None when the LLM produces an incomplete response."
)
# check if the model supports structured output, warn if it does not
known_good = False
# todo: we need to store model: Model in this class
# instead of model: str (= Model.id)
        # this should be: if not self.model.supports_structured_output: warnings.warn...
candidates = [
model for model in self.available_models if model.id == self.model
]
if not candidates: # user must have specified the model themselves
known_good = False
else:
assert len(candidates) == 1, "Multiple models with the same id"
known_good = candidates[0].supports_structured_output is True
if not known_good:
warnings.warn(
f"Model '{self.model}' is not known to support structured output. "
"Your output may fail at inference time."
)
if isinstance(schema, dict):
output_parser: BaseOutputParser = JsonOutputParser()
nvext_param: Dict[str, Any] = {"guided_json": schema}
elif issubclass(schema, enum.Enum):
# langchain's EnumOutputParser is not in langchain_core
# and doesn't support streaming. this is a simple implementation
# that supports streaming with our semantics of returning None
# if no complete object can be constructed.
class EnumOutputParser(BaseOutputParser):
enum: Type[enum.Enum]
def parse(self, response: str) -> Any:
try:
return self.enum(response.strip())
except ValueError:
pass
return None
# guided_choice only supports string choices
choices = [choice.value for choice in schema]
if not all(isinstance(choice, str) for choice in choices):
# instead of erroring out we could coerce the enum values to
# strings, but would then need to coerce them back to their
# original type for Enum construction.
raise ValueError(
"Enum schema must only contain string choices. "
"Use StrEnum or ensure all member values are strings."
)
output_parser = EnumOutputParser(enum=schema)
nvext_param = {"guided_choice": choices}
elif is_basemodel_subclass(schema):
# PydanticOutputParser does not support streaming. what we do
# instead is ignore all inputs that are incomplete wrt the
# underlying Pydantic schema. if the entire input is invalid,
# we return None.
class ForgivingPydanticOutputParser(PydanticOutputParser):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Any:
try:
return super().parse_result(result, partial=partial)
except OutputParserException:
pass
return None
output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
nvext_param = {"guided_json": schema.schema()}
else:
raise ValueError(
"Schema must be a Pydantic object, a dictionary "
"representing a JSON schema, or an Enum."
)
return super().bind(nvext=nvext_param) | output_parser