Source code for langchain_community.embeddings.tensorflow_hub

from typing import Any, List

from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict

# Model URL used as the default value of ``TensorflowHubEmbeddings.model_url``.
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"


class TensorflowHubEmbeddings(BaseModel, Embeddings):
    """TensorflowHub embedding models.

    To use, you should have the ``tensorflow-hub`` and ``tensorflow_text``
    python packages installed; both are imported when an instance is created.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import TensorflowHubEmbeddings

            url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
            tf = TensorflowHubEmbeddings(model_url=url)
    """

    embed: Any = None  #: :meta private:
    model_url: str = DEFAULT_MODEL_URL
    """URL of the model passed to ``tensorflow_hub.load``."""

    model_config = ConfigDict(
        extra="forbid",
        protected_namespaces=(),
    )

    def __init__(self, **kwargs: Any):
        """Validate the fields, import the TensorFlow packages and load the model.

        Args:
            **kwargs: Field values forwarded to the pydantic ``BaseModel``
                initializer (e.g. ``model_url``).

        Raises:
            ImportError: If ``tensorflow_hub`` or ``tensorflow_text`` cannot
                be imported. The original import failure is chained as the
                cause.
        """
        super().__init__(**kwargs)
        try:
            import tensorflow_hub
        except ImportError as e:
            raise ImportError(
                "Could not import tensorflow-hub python package. "
                "Please install it with `pip install tensorflow-hub`."
            ) from e
        try:
            # Imported but never referenced directly; presumably needed for
            # its import-time side effects on the loaded model -- TODO confirm.
            import tensorflow_text  # noqa
        except ImportError as e:
            raise ImportError(
                "Could not import tensorflow_text python package. "
                "Please install it with `pip install tensorflow_text`."
            ) from e

        self.embed = tensorflow_hub.load(self.model_url)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a TensorflowHub embedding model.

        Newlines in each text are replaced with spaces before embedding.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input yields an
            empty list without calling the model.
        """
        cleaned = [text.replace("\n", " ") for text in texts]
        if not cleaned:
            # Nothing to embed; skip the model call entirely.
            return []
        embeddings = self.embed(cleaned).numpy()
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a TensorflowHub embedding model.

        Newlines in the text are replaced with spaces before embedding.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        # The model is called with a one-element batch; take its only row.
        embedding = self.embed([text]).numpy()[0]
        return embedding.tolist()