Source code for langchain_community.embeddings.textembed
"""TextEmbed: Embedding Inference ServerTextEmbed provides a high-throughput, low-latency solution for serving embeddings.It supports various sentence-transformer models.Now, it includes the ability to deploy image embedding models.TextEmbed offers flexibility and scalability for diverse applications.TextEmbed is maintained by Keval Dekivadiya and is licensed under the Apache-2.0 license."""# noqa: E501importasynciofromconcurrent.futuresimportThreadPoolExecutorfromtypingimportAny,Callable,Dict,List,Optional,Tuple,Unionimportaiohttpimportnumpyasnpimportrequestsfromlangchain_core.embeddingsimportEmbeddingsfromlangchain_core.utilsimportfrom_env,secret_from_envfrompydanticimportBaseModel,ConfigDict,Field,SecretStr,model_validatorfromtyping_extensionsimportSelf__all__=["TextEmbedEmbeddings"]
class TextEmbedEmbeddings(BaseModel, Embeddings):
    """
    LangChain embeddings backed by a TextEmbed inference server.

    Attributes:
        model: The TextEmbed model ID to use for embeddings.
        api_url: The base URL for the TextEmbed API. When not passed, it is
            read from the ``TEXTEMBED_API_URL`` environment variable, falling
            back to ``http://localhost:8000/v1``.
        api_key: The API key for authenticating with the TextEmbed API. When
            not passed, it is read from the ``TEXTEMBED_API_KEY`` environment
            variable.
        client: The ``AsyncOpenAITextEmbedEmbeddingClient`` built by
            ``validate_environment``; it performs the actual HTTP calls.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import TextEmbedEmbeddings

            embeddings = TextEmbedEmbeddings(
                model="sentence-transformers/clip-ViT-B-32",
                api_url="http://localhost:8000/v1",
                api_key="<API_KEY>"
            )

    For more information: https://github.com/kevaldekivadiya2415/textembed/blob/main/docs/setup.md
    """  # noqa: E501

    model: str
    """Underlying TextEmbed model id."""
    api_url: str = Field(
        default_factory=from_env(
            "TEXTEMBED_API_URL", default="http://localhost:8000/v1"
        )
    )
    """Endpoint URL to use."""
    api_key: SecretStr = Field(default_factory=secret_from_env("TEXTEMBED_API_KEY"))
    """API Key for authentication"""
    client: Any = None
    """TextEmbed client."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Build the TextEmbed HTTP client from the resolved URL and API key.

        Runs after field validation, so ``api_url`` and ``api_key`` already hold
        their final (explicit or environment-derived) values.
        """
        raw_key = self.api_key.get_secret_value()
        self.client = AsyncOpenAITextEmbedEmbeddingClient(
            host=self.api_url,
            api_key=raw_key,
        )
        return self

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts synchronously via the TextEmbed endpoint.

        Args:
            texts (List[str]): Texts to embed.

        Returns:
            List[List[float]]: One embedding per input text, in input order.
        """
        return self.client.embed(model=self.model, texts=texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts asynchronously via the TextEmbed endpoint.

        Args:
            texts (List[str]): Texts to embed.

        Returns:
            List[List[float]]: One embedding per input text, in input order.
        """
        return await self.client.aembed(model=self.model, texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string synchronously.

        Args:
            text (str): The query to embed.

        Returns:
            List[float]: The embedding of ``text``.
        """
        vectors = self.embed_documents([text])
        return vectors[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query string asynchronously.

        Args:
            text (str): The query to embed.

        Returns:
            List[float]: The embedding of ``text``.
        """
        return (await self.aembed_documents([text]))[0]
class AsyncOpenAITextEmbedEmbeddingClient:
    """
    A client to handle synchronous and asynchronous requests to the TextEmbed API.

    Attributes:
        host (str): The base URL for the TextEmbed API.
        api_key (Optional[str]): The API key for authenticating with the TextEmbed API.
            When ``None``, no ``Authorization`` header is sent.
        aiosession (Optional[aiohttp.ClientSession]): Caller-owned aiohttp session used
            by ``aembed``. When ``None``, ``aembed`` opens (and closes) its own session.
        _batch_size (int): Maximum batch size for a single request.
    """  # noqa: E501

    def __init__(
        self,
        host: str = "http://localhost:8000/v1",
        api_key: Union[str, None] = None,
        aiosession: Optional["aiohttp.ClientSession"] = None,
    ) -> None:
        """Store connection settings.

        Raises:
            ValueError: If ``host`` is ``None`` or shorter than 3 characters.
        """
        self.host = host
        self.api_key = api_key
        self.aiosession = aiosession

        if self.host is None or len(self.host) < 3:
            raise ValueError("Parameter `host` must be set to a valid URL")
        self._batch_size = 256

    @staticmethod
    def _permute(
        texts: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """
        Sorts texts in descending order of `sorter` (longest first by default)
        and provides a function to restore the original order.

        Args:
            texts (List[str]): List of texts to sort.
            sorter (Callable, optional): Sort-key function, defaults to length.

        Returns:
            Tuple[List[str], Callable]: Sorted texts and a function that maps a
            list aligned with the sorted texts back to the original order.
        """  # noqa: E501
        if len(texts) == 1:
            # Nothing to reorder.
            return texts, lambda t: t
        # Negated keys -> descending order; presumably so that texts of similar
        # length end up in the same batch (TODO confirm intent).
        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
        texts_sorted = [texts[idx] for idx in length_sorted_idx]

        # argsort of a permutation is its inverse permutation.
        return texts_sorted, lambda unsorted_embeddings: [
            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
        ]

    def _batch(self, texts: List[str]) -> List[List[str]]:
        """
        Splits a list of texts into batches of size max `self._batch_size`.

        Args:
            texts (List[str]): List of texts to split.

        Returns:
            List[List[str]]: List of batches of texts.
        """
        size = self._batch_size
        return [texts[start : start + size] for start in range(0, len(texts), size)]

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        """
        Merges batches into a single flat list, preserving order.

        Args:
            batch_of_texts (List[List[Any]]): List of batches.

        Returns:
            List[Any]: Concatenation of all batches.
        """
        return [item for batch in batch_of_texts for item in batch]

    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
        """
        Builds the keyword arguments for the POST request to the embedding endpoint.

        Used by both the sync (``requests.post``) and async (``session.post``)
        paths. A trailing ``/`` on ``host`` is ignored so the URL never
        contains ``//embedding``.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            Dict[str, Any]: Dictionary of POST request parameters.
        """
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }
        # Only authenticate when a key is configured; otherwise the literal
        # string "Bearer None" would be sent.
        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return dict(
            url=f"{self.host.rstrip('/')}/embedding",
            headers=headers,
            json=dict(
                input=texts,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        """
        Sends a synchronous request to the embedding endpoint.

        Args:
            model (str): The model to use for embedding.
            batch_texts (List[str]): Batch of texts to embed.

        Returns:
            List[List[float]]: List of embeddings for the batch.

        Raises:
            Exception: If the response status is not 200.
        """
        response = requests.post(
            **self._kwargs_post_request(model=model, texts=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"TextEmbed responded with an unexpected status message "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of texts synchronously.

        Batches are sent concurrently from a thread pool when there is more
        than one batch.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: List of embeddings, in the same order as ``texts``.
        """
        if not texts:
            return []

        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(map(*map_args))
        else:
            with ThreadPoolExecutor(32) as pool:
                embeddings_batch_perm = list(pool.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)

    async def _async_request(
        self, session: "aiohttp.ClientSession", **kwargs: Any
    ) -> List[List[float]]:
        """
        Sends an asynchronous request to the embedding endpoint.

        Args:
            session (aiohttp.ClientSession): The aiohttp session for the request.
            kwargs (Any): POST request parameters (see ``_kwargs_post_request``).

        Returns:
            List[List[float]]: List of embeddings for the request.

        Raises:
            Exception: If the response status is not 200.
        """
        async with session.post(**kwargs) as response:  # type: ignore
            if response.status != 200:
                # In aiohttp, `text` is a coroutine method; it must be awaited
                # to obtain the response body for the error message.
                body = await response.text()
                raise Exception(
                    f"TextEmbed responded with an unexpected status message "
                    f"{response.status}: {body}"
                )
            embedding = (await response.json())["data"]
            return [e["embedding"] for e in embedding]

    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of texts asynchronously.

        Uses ``self.aiosession`` when one was supplied (it is left open);
        otherwise a temporary session limited to 32 connections is created.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: List of embeddings, in the same order as ``texts``.
        """
        if not texts:
            return []

        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        async def _gather_batches(
            session: "aiohttp.ClientSession",
        ) -> List[List[List[float]]]:
            # Fire all batch requests concurrently on the given session.
            return list(
                await asyncio.gather(
                    *[
                        self._async_request(
                            session=session,
                            **self._kwargs_post_request(model=model, texts=batch),
                        )
                        for batch in perm_texts_batched
                    ]
                )
            )

        if self.aiosession is not None:
            embeddings_batch_perm = await _gather_batches(self.aiosession)
        else:
            async with aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=32)
            ) as session:
                embeddings_batch_perm = await _gather_batches(session)

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)