from typing import Any, Dict, List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.outputs import LLMResult
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

# Internal helpers; module paths assumed from the package layout.
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var

# _DEFAULT_MODEL_NAME and _DEFAULT_BATCH_SIZE are module-level constants
# defined earlier in this file.


class NVIDIAEmbeddings(BaseModel, Embeddings):
    """
    Client to NVIDIA embeddings models.

    Fields:
    - model: str, the name of the model to use
    - truncate: "NONE", "START", "END", truncate input text if it exceeds
        the model's maximum token length. Default is "NONE", which raises
        an error if an input is too long.
    - dimensions: int, the number of dimensions for the embeddings. This
        parameter is not supported by all models.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    _client: _NVIDIAClient = PrivateAttr()
    base_url: Optional[str] = Field(
        default=None,
        description="Base url for model listing and invocation",
    )
    model: Optional[str] = Field(None, description="Name of the model to invoke")
    truncate: Literal["NONE", "START", "END"] = Field(
        default="NONE",
        description=(
            "Truncate input text if it exceeds the model's maximum token length. "
            "Default is 'NONE', which raises an error if an input is too long."
        ),
    )
    dimensions: Optional[int] = Field(
        default=None,
        description=(
            "The number of dimensions for the embeddings. This parameter is not "
            "supported by all models."
        ),
    )
    max_batch_size: int = Field(default=_DEFAULT_BATCH_SIZE)

    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIAEmbeddings embedder.

        This class provides access to a NVIDIA NIM for embedding. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for embedding.
            nvidia_api_key (str): The API key to use for connecting to the
                hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
                Format for base URL is http://host:port
            truncate (str): "NONE", "START", "END", truncate input text if it
                exceeds the model's context length. Default is "NONE", which
                raises an error if an input is too long.
            dimensions (int): The number of dimensions for the embeddings. This
                parameter is not supported by all models.

        API Key:
        - The recommended way to provide the API key is through the
            `NVIDIA_API_KEY` environment variable.

        Base URL:
        - Connect to a self-hosted model with NVIDIA NIM using the `base_url`
            arg to link to the local host at localhost:8000:
            embedder = NVIDIAEmbeddings(base_url="http://localhost:8000/v1")
        """
        super().__init__(**kwargs)
        # allow nvidia_base_url as an alternative for base_url
        base_url = kwargs.pop("nvidia_base_url", self.base_url)
        # allow nvidia_api_key as an alternative for api_key
        api_key = kwargs.pop("nvidia_api_key", kwargs.pop("api_key", None))
        self._client = _NVIDIAClient(
            **({"base_url": base_url} if base_url else {}),  # only pass if set
            mdl_name=self.model,
            default_hosted_model_name=_DEFAULT_MODEL_NAME,
            **({"api_key": api_key} if api_key else {}),  # only pass if set
            infer_path="{base_url}/embeddings",
            cls=self.__class__.__name__,
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.mdl_name
        # same for base_url
        self.base_url = self._client.base_url

    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.
        """
        return self._client.get_available_models(self.__class__.__name__)
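    # Usage sketch: the model name below is illustrative, and the hosted path
    # assumes `NVIDIA_API_KEY` is set in the environment.
    #
    #   hosted = NVIDIAEmbeddings(model="nvidia/nv-embedqa-e5-v5", truncate="END")
    #   local = NVIDIAEmbeddings(base_url="http://localhost:8000/v1")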
    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.
        """
        return cls(**kwargs).available_models
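    # Sketch: listing models without keeping an instance around. Assumes
    # `NVIDIA_API_KEY` is set so the hosted endpoint can be queried, and that
    # Model exposes an `id` field (inferred from the package's Model type).
    #
    #   for m in NVIDIAEmbeddings.get_available_models():
    #       print(m.id)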
    def _embed(
        self, texts: List[str], model_type: Literal["passage", "query"]
    ) -> List[List[float]]:
        """Embed a list of texts as either passage or query type"""
        # API Catalog API -
        #  input: str | list[str]              -- char limit depends on model
        #  model: str                          -- model name, e.g. NV-Embed-QA
        #  encoding_format: "float" | "base64"
        #  input_type: "query" | "passage"
        #  user: str                           -- ignored
        #  truncate: "NONE" | "START" | "END"  -- default "NONE", error raised
        #                                         if an input is too long
        #  dimensions: int                     -- not supported by all models
        payload: Dict[str, Any] = {
            "input": texts,
            "model": self.model,
            "encoding_format": "float",
            "input_type": model_type,
        }
        if self.truncate:
            payload["truncate"] = self.truncate
        if self.dimensions:
            payload["dimensions"] = self.dimensions

        response = self._client.get_req(payload=payload)
        response.raise_for_status()
        result = response.json()
        data = result.get("data", result)
        if not isinstance(data, list):
            raise ValueError(f"Expected data with a list of embeddings. Got: {data}")
        embedding_list = [(res["embedding"], res["index"]) for res in data]
        self._invoke_callback_vars(result)
        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]
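    # For reference, the parsing in _embed assumes an OpenAI-style embeddings
    # response; the shape below is inferred from the code, not from a spec:
    #
    #   {"data": [{"embedding": [0.0123, ...], "index": 0}, ...],
    #    "usage": {"prompt_tokens": 8, "total_tokens": 8}}
    #
    # Entries are re-sorted by "index" so output order matches input order.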
    def embed_query(self, text: str) -> List[float]:
        """Input pathway for query embeddings."""
        return self._embed([text], model_type="query")[0]
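    # Sketch: embedding a single query string returns one vector.
    #
    #   vector = embedder.embed_query("What is a NIM?")
    #   assert isinstance(vector[0], float)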
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Input pathway for document embeddings."""
        if not isinstance(texts, list) or not all(
            isinstance(text, str) for text in texts
        ):
            raise ValueError(f"`texts` must be a list of strings, given: {repr(texts)}")

        all_embeddings = []
        for i in range(0, len(texts), self.max_batch_size):
            batch = texts[i : i + self.max_batch_size]
            all_embeddings.extend(self._embed(batch, model_type="passage"))
        return all_embeddings
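    # Sketch: documents are sent in batches of max_batch_size, so a list
    # longer than the batch size results in multiple requests.
    #
    #   vectors = embedder.embed_documents(["doc one", "doc two", "doc three"])
    #   assert len(vectors) == 3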
    def _invoke_callback_vars(self, response: dict) -> None:
        """Invoke the callback context variables if there are any."""
        callback_vars = [
            usage_callback_var.get(),
        ]
        llm_output = {**response, "model_name": self.model}
        result = LLMResult(generations=[[]], llm_output=llm_output)
        for cb_var in callback_vars:
            if cb_var:
                cb_var.on_llm_end(result)
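
# Minimal end-to-end sketch, assuming the package is installed and
# `NVIDIA_API_KEY` is exported; the model name is illustrative.
if __name__ == "__main__":
    embedder = NVIDIAEmbeddings(model="nvidia/nv-embedqa-e5-v5", truncate="END")
    doc_vectors = embedder.embed_documents(
        ["NVIDIA NIM serves models behind an OpenAI-compatible API."]
    )  # sent with input_type="passage"
    query_vector = embedder.embed_query("What is NIM?")  # input_type="query"
    print(len(doc_vectors), len(query_vector))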