Source code for langchain_nvidia_ai_endpoints.embeddings

"""Embeddings Components Derived from NVEModel/Embeddings"""

import os
import warnings
from typing import Any, Dict, List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.pydantic_v1 import (
    BaseModel,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var


class NVIDIAEmbeddings(BaseModel, Embeddings):
    """
    Client to NVIDIA embeddings models.

    Fields:
    - model: str, the name of the model to use
    - truncate: "NONE", "START", "END", truncate input text if it exceeds the model's
        maximum token length. Default is "NONE", which raises an error if an input is
        too long.
    - max_batch_size: int, the maximum number of texts sent per request by
        `embed_documents`. Must be a positive integer.
    """

    class Config:
        # Re-run field validation whenever an attribute is assigned after init.
        validate_assignment = True

    _client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
    _default_model_name: str = "nvidia/nv-embedqa-e5-v5"
    _default_max_batch_size: int = 50
    _default_base_url: str = "https://integrate.api.nvidia.com/v1"
    base_url: str = Field(
        description="Base url for model listing an invocation",
    )
    model: Optional[str] = Field(description="Name of the model to invoke")
    truncate: Literal["NONE", "START", "END"] = Field(
        default="NONE",
        description=(
            "Truncate input text if it exceeds the model's maximum token length. "
            "Default is 'NONE', which raises an error if an input is too long."
        ),
    )
    max_batch_size: int = Field(default=_default_max_batch_size)
    model_type: Optional[Literal["passage", "query"]] = Field(
        None, description="(DEPRECATED) The type of text to be embedded."
    )

    # Name of the environment variable consulted for the base URL; its
    # lower-cased form is also accepted as a constructor keyword.
    _base_url_var = "NVIDIA_BASE_URL"

    @root_validator(pre=True)
    def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve `base_url` before field validation runs.

        The first truthy value wins, in this order: the `nvidia_base_url`
        keyword, the `base_url` keyword, the `NVIDIA_BASE_URL` environment
        variable, and finally the class default.
        """
        values["base_url"] = (
            values.get(cls._base_url_var.lower())
            or values.get("base_url")
            or os.getenv(cls._base_url_var)
            or cls._default_base_url
        )
        return values

    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIAEmbeddings embedder.

        This class provides access to a NVIDIA NIM for embedding. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for embedding.
            nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
                            Format for base URL is http://host:port
            truncate (str): "NONE", "START", "END", truncate input text if it exceeds
                            the model's context length. Default is "NONE", which raises
                            an error if an input is too long.

        API Key:
        - The recommended way to provide the API key is through the `NVIDIA_API_KEY`
            environment variable.

        Base URL:
        - Connect to a self-hosted model with NVIDIA NIM using the `base_url` arg to
            link to a NIM running on the local host, e.g. at localhost:8000:
            embedder = NVIDIAEmbeddings(base_url="http://localhost:8000/v1")
        """
        super().__init__(**kwargs)
        self._client = _NVIDIAClient(
            base_url=self.base_url,
            model_name=self.model,
            default_hosted_model_name=self._default_model_name,
            # `nvidia_api_key` takes precedence whenever that keyword is present.
            api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
            infer_path="{base_url}/embeddings",
            cls=self.__class__.__name__,
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.model_name

        # todo: remove when nvolveqa_40k is removed from MODEL_TABLE
        if "model" in kwargs and kwargs["model"] in [
            "playground_nvolveqa_40k",
            "nvolveqa_40k",
        ]:
            warnings.warn(
                'Setting truncate="END" for nvolveqa_40k backward compatibility'
            )
            self.truncate = "END"

    @validator("model_type")
    def _validate_model_type(
        cls, v: Optional[Literal["passage", "query"]]
    ) -> Optional[Literal["passage", "query"]]:
        """Warn when the deprecated `model_type` field is set; pass it through."""
        if v:
            warnings.warn(
                "Warning: `model_type` is deprecated and will be removed "
                "in a future release. Please use `embed_query` or "
                "`embed_documents` appropriately."
            )
        return v

    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.
        """
        return self._client.get_available_models(self.__class__.__name__)

    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.

        A temporary instance is constructed from `kwargs` (the same keywords
        accepted by the constructor) and its `available_models` is returned.
        """
        return cls(**kwargs).available_models

    def _embed(
        self, texts: List[str], model_type: Literal["passage", "query"]
    ) -> List[List[float]]:
        """
        Embed a batch of texts as either passage or query type.

        Args:
            texts: The texts to embed in a single request.
            model_type: Sent to the endpoint as `input_type`.

        Returns:
            One embedding per response item, ordered by each item's `index`.

        Raises:
            ValueError: If the response data is not a list.
        """
        # API Catalog API -
        #  input: str | list[str]              -- char limit depends on model
        #  model: str                          -- model name, e.g. NV-Embed-QA
        #  encoding_format: "float" | "base64"
        #  input_type: "query" | "passage"
        #  user: str                           -- ignored
        #  truncate: "NONE" | "START" | "END"  -- default "NONE", error raised if
        #                                         an input is too long
        payload: Dict[str, Any] = {
            "input": texts,
            "model": self.model,
            "encoding_format": "float",
            "input_type": model_type,
        }
        if self.truncate:
            payload["truncate"] = self.truncate

        response = self._client.get_req(
            payload=payload,
        )
        response.raise_for_status()
        result = response.json()
        # Fall back to the whole body when there is no "data" key.
        data = result.get("data", result)
        if not isinstance(data, list):
            raise ValueError(f"Expected data with a list of embeddings. Got: {data}")
        embedding_list = [(res["embedding"], res["index"]) for res in data]
        self._invoke_callback_vars(result)
        # Sort on the reported index rather than trusting response order.
        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]

    def embed_query(self, text: str) -> List[float]:
        """Input pathway for query embeddings."""
        return self._embed([text], model_type=self.model_type or "query")[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Input pathway for document embeddings.

        The texts are sent in consecutive batches of at most `max_batch_size`
        items and the per-batch results are concatenated in order.

        Raises:
            ValueError: If `texts` is not a list of strings, or if
                `max_batch_size` is not a positive integer.
        """
        if not isinstance(texts, list) or not all(
            isinstance(text, str) for text in texts
        ):
            raise ValueError(f"`texts` must be a list of strings, given: {repr(texts)}")

        # A negative step would make `range` empty and silently drop every
        # input; a zero step fails with an unrelated-looking `range` error.
        if self.max_batch_size < 1:
            raise ValueError(
                "`max_batch_size` must be a positive integer, "
                f"given: {self.max_batch_size}"
            )

        all_embeddings = []
        for i in range(0, len(texts), self.max_batch_size):
            batch = texts[i : i + self.max_batch_size]
            all_embeddings.extend(
                self._embed(batch, model_type=self.model_type or "passage")
            )
        return all_embeddings

    def _invoke_callback_vars(self, response: dict) -> None:
        """
        Invoke the callback context variables if there are any.

        The response, extended with a `model_name` entry, is wrapped as the
        `llm_output` of an `LLMResult` with no generations and passed to
        `on_llm_end` of every callback variable that is currently set.
        """
        callback_vars = [
            usage_callback_var.get(),
        ]
        llm_output = {**response, "model_name": self.model}
        result = LLMResult(generations=[[]], llm_output=llm_output)
        for cb_var in callback_vars:
            if cb_var:
                cb_var.on_llm_end(result)