"""
TextEmbed: Embedding Inference Server

TextEmbed provides a high-throughput, low-latency solution for serving embeddings.
It supports various sentence-transformer models and can also serve image
embedding models, offering flexibility and scalability for diverse applications.

TextEmbed is maintained by Keval Dekivadiya and is licensed under the Apache-2.0 license.
"""  # noqa: E501
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import aiohttp
import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.utils import from_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
# Public export list; the AsyncOpenAITextEmbedEmbeddingClient class defined
# in this module is not included.
__all__ = ["TextEmbedEmbeddings"]
class TextEmbedEmbeddings(BaseModel, Embeddings):
    """
    LangChain ``Embeddings`` implementation backed by a TextEmbed server.

    After pydantic validation, an ``AsyncOpenAITextEmbedEmbeddingClient`` is
    built from ``api_url`` and ``api_key``; every sync and async embedding
    method delegates to that client.

    Attributes:
        model : Identifier of the TextEmbed model used to embed inputs.
        api_url : Base URL of the TextEmbed API (read from the
            ``TEXTEMBED_API_URL`` environment variable when not given,
            falling back to ``http://localhost:8000/v1``).
        api_key : API key (read from ``TEXTEMBED_API_KEY`` when not given).
        client : The HTTP client created during validation.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import TextEmbedEmbeddings

            embeddings = TextEmbedEmbeddings(
                model="sentence-transformers/clip-ViT-B-32",
                api_url="http://localhost:8000/v1",
                api_key="<API_KEY>"
            )

    For more information: https://github.com/kevaldekivadiya2415/textembed/blob/main/docs/setup.md
    """  # noqa: E501

    model_config = ConfigDict(
        extra="forbid",
    )

    model: str
    """Identifier of the TextEmbed model to embed with."""

    api_url: str = Field(
        default_factory=from_env(
            "TEXTEMBED_API_URL", default="http://localhost:8000/v1"
        )
    )
    """Base URL of the TextEmbed API."""

    api_key: SecretStr = Field(default_factory=secret_from_env("TEXTEMBED_API_KEY"))
    """Secret used as the bearer token for requests."""

    client: Any = None
    """Client instance created in ``validate_environment``."""

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Create the TextEmbed client from the validated URL and API key."""
        secret = self.api_key.get_secret_value()
        self.client = AsyncOpenAITextEmbedEmbeddingClient(
            host=self.api_url, api_key=secret
        )
        return self

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents through the TextEmbed endpoint.

        Args:
            texts (List[str]): Inputs to embed.

        Returns:
            List[List[float]]: One embedding per input.
        """
        return self.client.embed(model=self.model, texts=texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed several documents through the TextEmbed endpoint.

        Args:
            texts (List[str]): Inputs to embed.

        Returns:
            List[List[float]]: One embedding per input.
        """
        return await self.client.aembed(model=self.model, texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query by sending it as a one-item document batch.

        Args:
            text (str): The query to embed.

        Returns:
            List[float]: The query's embedding.
        """
        single = [text]
        return self.embed_documents(single)[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query as a one-item document batch.

        Args:
            text (str): The query to embed.

        Returns:
            List[float]: The query's embedding.
        """
        return (await self.aembed_documents([text]))[0]
class AsyncOpenAITextEmbedEmbeddingClient:
    """
    A client to handle synchronous and asynchronous requests to the TextEmbed API.

    Inputs are ordered longest-first, split into batches of at most
    ``_batch_size`` items, and each batch is sent as its own POST request to
    ``{host}/embedding``. The embeddings are then returned in the caller's
    original order.

    Attributes:
        host (str): The base URL for the TextEmbed API.
        api_key (Optional[str]): The API key for authenticating with the
            TextEmbed API. When ``None``, no ``Authorization`` header is sent.
        aiosession (Optional[aiohttp.ClientSession]): Optional session reused
            by :meth:`aembed`. This client never closes it. When ``None``,
            :meth:`aembed` opens and closes its own session per call.
        _batch_size (int): Maximum batch size for a single request.
    """  # noqa: E501

    def __init__(
        self,
        host: str = "http://localhost:8000/v1",
        api_key: Union[str, None] = None,
        aiosession: Optional["aiohttp.ClientSession"] = None,
    ) -> None:
        """Create a client.

        Args:
            host: Base URL of the TextEmbed API; ``/embedding`` is appended.
            api_key: Bearer token. It is omitted from requests when ``None``.
            aiosession: Session to reuse for async requests (caller-owned).

        Raises:
            ValueError: If ``host`` is ``None`` or shorter than 3 characters.
        """
        self.host = host
        self.api_key = api_key
        self.aiosession = aiosession
        if self.host is None or len(self.host) < 3:
            raise ValueError("Parameter `host` must be set to a valid URL")
        self._batch_size = 256

    @staticmethod
    def _permute(
        texts: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """
        Sorts texts in descending order of ``sorter`` (longest first by default)
        and provides a function to restore the original order.

        Because batches are contiguous slices of the sorted list, each batch
        then holds texts of similar length.

        Args:
            texts (List[str]): List of texts to sort.
            sorter (Callable, optional): Sort key, defaults to ``len``.

        Returns:
            Tuple[List[str], Callable]: Sorted texts and a function that maps a
            list aligned with the sorted texts back to the original order.
        """  # noqa: E501
        if len(texts) == 1:
            return texts, lambda t: t
        # Negate the keys so that argsort (ascending) yields a descending order.
        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
        texts_sorted = [texts[idx] for idx in length_sorted_idx]
        # The argsort of a permutation is its inverse permutation.
        inverse_idx = np.argsort(length_sorted_idx)
        return texts_sorted, lambda unsorted_embeddings: [
            unsorted_embeddings[idx] for idx in inverse_idx
        ]

    def _batch(self, texts: List[str]) -> List[List[str]]:
        """
        Splits a list of texts into batches of size max `self._batch_size`.

        Args:
            texts (List[str]): List of texts to split.

        Returns:
            List[List[str]]: List of batches of texts.
        """
        if len(texts) == 1:
            return [texts]
        size = self._batch_size
        return [texts[start : start + size] for start in range(0, len(texts), size)]

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        """
        Merges batches into a single flat list, preserving order.

        Args:
            batch_of_texts (List[List[Any]]): List of batches.

        Returns:
            List[Any]: Concatenation of all batches.
        """
        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
            return batch_of_texts[0]
        texts: List[Any] = []
        for sublist in batch_of_texts:
            texts.extend(sublist)
        return texts

    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
        """
        Builds the keyword arguments for one embedding POST request.

        The same kwargs are used by the sync (``requests.post``) and async
        (``session.post``) paths.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            Dict[str, Any]: ``url``, ``headers`` and ``json`` request parameters.
        """
        headers: Dict[str, str] = {
            "accept": "application/json",
            "content-type": "application/json",
        }
        # Only authenticate when a key is configured; otherwise the header
        # would literally be "Bearer None".
        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return dict(
            url=f"{self.host}/embedding",
            headers=headers,
            json=dict(
                input=texts,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        """
        Sends a synchronous request to the embedding endpoint.

        Args:
            model (str): The model to use for embedding.
            batch_texts (List[str]): Batch of texts to embed.

        Returns:
            List[List[float]]: List of embeddings for the batch.

        Raises:
            Exception: If the response status is not 200.
        """
        # NOTE(review): no `timeout=` is passed, so a stalled server blocks
        # this call indefinitely; consider making a timeout configurable.
        response = requests.post(
            **self._kwargs_post_request(model=model, texts=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"TextEmbed responded with an unexpected status message "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of texts synchronously.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: One embedding per text, in the input order.
        """
        if not texts:
            return []
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(map(*map_args))
        else:
            # Each batch is an independent HTTP request; run up to 32 at once.
            with ThreadPoolExecutor(32) as p:
                embeddings_batch_perm = list(p.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)

    async def _async_request(
        self, session: "aiohttp.ClientSession", **kwargs: Any
    ) -> List[List[float]]:
        """
        Sends an asynchronous request to the embedding endpoint.

        Args:
            session (aiohttp.ClientSession): The aiohttp session for the request.
            kwargs: POST request parameters (see ``_kwargs_post_request``).

        Returns:
            List[List[float]]: List of embeddings for the request.

        Raises:
            Exception: If the response status is not 200.
        """
        async with session.post(**kwargs) as response:  # type: ignore
            if response.status != 200:
                # In aiohttp `text` is a coroutine method; it must be awaited
                # to obtain the body (the original interpolated the method).
                body = await response.text()
                raise Exception(
                    f"TextEmbed responded with an unexpected status message "
                    f"{response.status}: {body}"
                )
            embedding = (await response.json())["data"]
            return [e["embedding"] for e in embedding]

    async def _agather_batches(
        self,
        session: "aiohttp.ClientSession",
        model: str,
        batches: List[List[str]],
    ) -> List[List[List[float]]]:
        """Send all batches concurrently over `session`; results keep batch order."""
        return await asyncio.gather(
            *[
                self._async_request(
                    session=session,
                    **self._kwargs_post_request(model=model, texts=batch),
                )
                for batch in batches
            ]
        )

    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of texts asynchronously.

        Uses ``self.aiosession`` when one was supplied (leaving it open).
        Otherwise it opens a temporary session limited to 32 connections.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: One embedding per text, in the input order.
        """
        if not texts:
            return []
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        if self.aiosession is not None:
            # Caller-owned session: reuse it and leave closing to the caller.
            embeddings_batch_perm = await self._agather_batches(
                self.aiosession, model, perm_texts_batched
            )
        else:
            async with aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=32)
            ) as session:
                embeddings_batch_perm = await self._agather_batches(
                    session, model, perm_texts_batched
                )

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)