Source code for langchain_nvidia_ai_endpoints.embeddings
"""Embeddings Components Derived from NVEModel/Embeddings"""importosimportwarningsfromtypingimportAny,Dict,List,Literal,Optionalfromlangchain_core.embeddingsimportEmbeddingsfromlangchain_core.outputs.llm_resultimportLLMResultfromlangchain_core.pydantic_v1import(BaseModel,Field,PrivateAttr,root_validator,validator,)fromlangchain_nvidia_ai_endpoints._commonimport_NVIDIAClientfromlangchain_nvidia_ai_endpoints._staticsimportModelfromlangchain_nvidia_ai_endpoints.callbacksimportusage_callback_var
class NVIDIAEmbeddings(BaseModel, Embeddings):
    """
    Client to NVIDIA embeddings models.

    Fields:
    - model: str, the name of the model to use
    - truncate: "NONE", "START", "END", truncate input text if it exceeds
      the model's maximum token length. Default is "NONE", which raises
      an error if an input is too long.
    """

    class Config:
        # re-run field validators when attributes are assigned after init
        validate_assignment = True

    # low-level HTTP client; populated in __init__, never a pydantic field
    _client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
    # fallbacks used when the caller supplies no model / base_url
    _default_model_name: str = "nvidia/nv-embedqa-e5-v5"
    _default_max_batch_size: int = 50
    _default_base_url: str = "https://integrate.api.nvidia.com/v1"

    base_url: str = Field(
        description="Base url for model listing an invocation",
    )
    model: Optional[str] = Field(description="Name of the model to invoke")
    truncate: Literal["NONE", "START", "END"] = Field(
        default="NONE",
        description=(
            "Truncate input text if it exceeds the model's maximum token length. "
            "Default is 'NONE', which raises an error if an input is too long."
        ),
    )
    max_batch_size: int = Field(default=_default_max_batch_size)
    model_type: Optional[Literal["passage", "query"]] = Field(
        None, description="(DEPRECATED) The type of text to be embedded."
    )

    # environment variable consulted for the base URL
    _base_url_var = "NVIDIA_BASE_URL"

    @root_validator(pre=True)
    def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Resolution order: explicit `nvidia_base_url` kwarg, then `base_url`
        # kwarg, then the NVIDIA_BASE_URL env var, then the hosted default.
        values["base_url"] = (
            values.get(cls._base_url_var.lower())
            or values.get("base_url")
            or os.getenv(cls._base_url_var)
            or cls._default_base_url
        )
        return values

    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIAEmbeddings embedder.

        This class provides access to a NVIDIA NIM for embedding. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for embedding.
            nvidia_api_key (str): The API key to use for connecting to the
                hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
                Format for base URL is http://host:port
            truncate (str): "NONE", "START", "END", truncate input text if it
                exceeds the model's context length. Default is "NONE", which
                raises an error if an input is too long.

        API Key:
        - The recommended way to provide the API key is through the
          `NVIDIA_API_KEY` environment variable.

        Base URL:
        - Connect to a self-hosted model with NVIDIA NIM using the `base_url`
          arg to link to the local host, e.g. at localhost:8080:
            embedder = NVIDIAEmbeddings(base_url="http://localhost:8080/v1")
        """
        super().__init__(**kwargs)
        self._client = _NVIDIAClient(
            base_url=self.base_url,
            model_name=self.model,
            default_hosted_model_name=self._default_model_name,
            api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
            infer_path="{base_url}/embeddings",
            cls=self.__class__.__name__,
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.model_name

        # todo: remove when nvolveqa_40k is removed from MODEL_TABLE
        if "model" in kwargs and kwargs["model"] in [
            "playground_nvolveqa_40k",
            "nvolveqa_40k",
        ]:
            warnings.warn(
                'Setting truncate="END" for nvolveqa_40k backward compatibility'
            )
            self.truncate = "END"

    @validator("model_type")
    def _validate_model_type(
        cls, v: Optional[Literal["passage", "query"]]
    ) -> Optional[Literal["passage", "query"]]:
        # `model_type` is deprecated; warn whenever a non-None value is set
        if v:
            warnings.warn(
                "Warning: `model_type` is deprecated and will be removed "
                "in a future release. Please use `embed_query` or "
                "`embed_documents` appropriately."
            )
        return v

    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.
        """
        return self._client.get_available_models(self.__class__.__name__)
[docs]@classmethoddefget_available_models(cls,**kwargs:Any,)->List[Model]:""" Get a list of available models that work with NVIDIAEmbeddings. """returncls(**kwargs).available_models
def_embed(self,texts:List[str],model_type:Literal["passage","query"])->List[List[float]]:"""Embed a single text entry to either passage or query type"""# API Catalog API -# input: str | list[str] -- char limit depends on model# model: str -- model name, e.g. NV-Embed-QA# encoding_format: "float" | "base64"# input_type: "query" | "passage"# user: str -- ignored# truncate: "NONE" | "START" | "END" -- default "NONE", error raised if# an input is too longpayload={"input":texts,"model":self.model,"encoding_format":"float","input_type":model_type,}ifself.truncate:payload["truncate"]=self.truncateresponse=self._client.get_req(payload=payload,)response.raise_for_status()result=response.json()data=result.get("data",result)ifnotisinstance(data,list):raiseValueError(f"Expected data with a list of embeddings. Got: {data}")embedding_list=[(res["embedding"],res["index"])forresindata]self._invoke_callback_vars(result)return[x[0]forxinsorted(embedding_list,key=lambdax:x[1])]
[docs]defembed_query(self,text:str)->List[float]:"""Input pathway for query embeddings."""returnself._embed([text],model_type=self.model_typeor"query")[0]
[docs]defembed_documents(self,texts:List[str])->List[List[float]]:"""Input pathway for document embeddings."""ifnotisinstance(texts,list)ornotall(isinstance(text,str)fortextintexts):raiseValueError(f"`texts` must be a list of strings, given: {repr(texts)}")all_embeddings=[]foriinrange(0,len(texts),self.max_batch_size):batch=texts[i:i+self.max_batch_size]all_embeddings.extend(self._embed(batch,model_type=self.model_typeor"passage"))returnall_embeddings
def_invoke_callback_vars(self,response:dict)->None:"""Invoke the callback context variables if there are any."""callback_vars=[usage_callback_var.get(),]llm_output={**response,"model_name":self.model}result=LLMResult(generations=[[]],llm_output=llm_output)forcb_varincallback_vars:ifcb_var:cb_var.on_llm_end(result)