from __future__ import annotations
import warnings
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from pydantic import ConfigDict, Field, PrivateAttr
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
_DEFAULT_MODEL_NAME: str = "nvidia/mistral-nemo-minitron-8b-base"
class NVIDIA(LLM):
"""
LangChain LLM that uses the Completions API with NVIDIA NIMs.
"""
model_config = ConfigDict(
validate_assignment=True,
)
_client: _NVIDIAClient = PrivateAttr()
_default_model_name: str = "nvidia/mistral-nemo-minitron-8b-base"
base_url: Optional[str] = Field(
default=None,
description="Base url for model listing and invocation",
)
model: Optional[str] = Field(None, description="The model to use for completions.")
_init_args: Dict[str, Any] = PrivateAttr()
"""Stashed arguments given to the constructor that can be passed to
the Completions API endpoint."""
    def __check_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check kwargs, warn for unknown keys, and return a copy of the recognized keys.
        """
        completions_arguments = {
            "frequency_penalty",
            "max_tokens",
            "presence_penalty",
            "seed",
            "stop",
            "temperature",
            "top_p",
            "best_of",
            "echo",
            "logit_bias",
            "logprobs",
            "n",
            "suffix",
            "user",
            "stream",
        }

        recognized_kwargs = {
            k: v for k, v in kwargs.items() if k in completions_arguments
        }
        unrecognized_kwargs = set(kwargs) - completions_arguments
        if len(unrecognized_kwargs) > 0:
            warnings.warn(f"Unrecognized, ignored arguments: {unrecognized_kwargs}")

        return recognized_kwargs
    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIA LLM for the Completions API.

        This class provides access to an NVIDIA NIM for completions. By default, it
        connects to a hosted NIM, but it can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for completions.
            nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.

        API Key:
        - The recommended way to provide the API key is through the `NVIDIA_API_KEY`
          environment variable.

        Additional arguments that can be passed to the Completions API:
        - max_tokens (int): The maximum number of tokens to generate.
        - stop (str or List[str]): The stop sequence(s) to use for generating completions.
        - temperature (float): The temperature to use for generating completions.
        - top_p (float): The top-p value to use for generating completions.
        - frequency_penalty (float): The frequency penalty to apply to the completion.
        - presence_penalty (float): The presence penalty to apply to the completion.
        - seed (int): The seed to use for generating completions.

        These additional arguments can also be passed with `bind()`, e.g.
        `NVIDIA().bind(max_tokens=512)`, or passed directly to `invoke()` or `stream()`,
        e.g. `NVIDIA().invoke("prompt", max_tokens=512)`. See the usage sketch at the
        end of this module for an example.
        """
        super().__init__(**kwargs)

        # allow nvidia_base_url as an alternative for base_url
        base_url = kwargs.pop("nvidia_base_url", self.base_url)

        # allow nvidia_api_key as an alternative for api_key
        api_key = kwargs.pop("nvidia_api_key", kwargs.pop("api_key", None))

        self._client = _NVIDIAClient(
            **({"base_url": base_url} if base_url else {}),  # only pass if set
            mdl_name=self.model,
            default_hosted_model_name=_DEFAULT_MODEL_NAME,
            **({"api_key": api_key} if api_key else {}),  # only pass if set
            infer_path="{base_url}/completions",
            cls=self.__class__.__name__,
        )

        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.mdl_name

        # same for base_url
        self.base_url = self._client.base_url

        # stash all additional args that can be passed to the Completions API,
        # but first make sure we pull out any args that are processed elsewhere.
        for key in [
            "model",
            "nvidia_base_url",
            "base_url",
        ]:
            if key in kwargs:
                del kwargs[key]
        self._init_args = self.__check_kwargs(kwargs)
    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIA.
        """
        return self._client.get_available_models(self.__class__.__name__)
    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with the Completions API.
        """
        return cls(**kwargs).available_models
    @property
    def _llm_type(self) -> str:
        """
        Get the type of language model used by this LLM.
        Used for logging purposes only.
        """
        return "NVIDIA"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """
        Get parameters used to help identify the LLM.
        """
        return {
            "model": self.model,
            "base_url": self.base_url,
        }
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            **self._init_args,
            **self.__check_kwargs(kwargs),
        }
        if stop:
            payload["stop"] = stop

        if payload.get("stream", False):
            warnings.warn("stream set to true for non-streaming call, ignoring")
            del payload["stream"]

        response = self._client.get_req(payload=payload)
        response.raise_for_status()

        # todo: handle response's usage and system_fingerprint
        choices = response.json()["choices"]
        # todo: write a test for this by setting n > 1 on the request
        #  aug 2024: n > 1 is not supported by endpoints
        if len(choices) > 1:
            warnings.warn(
                f"Multiple choices in response, returning only the first: {choices}"
            )

        return choices[0]["text"]
    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            **self._init_args,
            **self.__check_kwargs(kwargs),
        }
        if stop:
            payload["stop"] = stop

        # we construct payload w/ **kwargs positioned to override stream=True,
        # this lets us know if a user passed stream=False
        if not payload.get("stream", True):
            warnings.warn("stream set to false for streaming call, ignoring")
            payload["stream"] = True

        for chunk in self._client.get_req_stream(payload=payload):
            content = chunk["content"]
            generation = GenerationChunk(text=content)
            if run_manager:  # todo: add tests for run_manager
                run_manager.on_llm_new_token(content, chunk=generation)
            yield generation
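

# A minimal usage sketch, not part of the library itself. It assumes the
# `NVIDIA_API_KEY` environment variable is set for the hosted endpoint, that the
# default model is served, and that `Model` entries expose an `id` field; the
# local `base_url` shown at the end is a hypothetical self-hosted NIM address.
if __name__ == "__main__":
    # Hosted NIM: the API key is read from NVIDIA_API_KEY.
    llm = NVIDIA(model="nvidia/mistral-nemo-minitron-8b-base")

    # Discover which models the endpoint serves.
    for m in llm.available_models:
        print(m.id)

    # Completions arguments can be passed per call ...
    print(llm.invoke("The capital of France is", max_tokens=32, temperature=0.2))

    # ... or bound once and reused across calls.
    bound = llm.bind(max_tokens=64)
    print(bound.invoke("Q: What is 2 + 2?\nA:"))

    # Streaming yields text chunks as they arrive.
    for token in llm.stream("Write one sentence about GPUs.", max_tokens=48):
        print(token, end="", flush=True)

    # Self-hosted NIM (hypothetical URL): no API key is required for a local server.
    # local_llm = NVIDIA(base_url="http://localhost:8000/v1")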