"""Source code for langchain.embeddings.base."""

import functools
from importlib import util
from typing import Any, List, Optional, Tuple, Union

from langchain_core._api import beta
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable

# Maps each supported provider key (the prefix accepted in "provider:model"
# strings) to the integration package that must be installed to use it.
# Package names are module-style (underscores); `_get_provider_list` converts
# them to their PyPI (hyphenated) form for display.
_SUPPORTED_PROVIDERS = {
    "azure_openai": "langchain_openai",
    "bedrock": "langchain_aws",
    "cohere": "langchain_cohere",
    "google_vertexai": "langchain_google_vertexai",
    "huggingface": "langchain_huggingface",
    "mistralai": "langchain_mistralai",
    "openai": "langchain_openai",
}


def _get_provider_list() -> str:
    """Return a human-readable bullet list of providers and their packages.

    Package names are shown in their PyPI (hyphenated) form.
    """
    bullet_lines = []
    for provider_key, package_name in _SUPPORTED_PROVIDERS.items():
        bullet_lines.append(f"  - {provider_key}: {package_name.replace('_', '-')}")
    return "\n".join(bullet_lines)


def _parse_model_string(model_name: str) -> Tuple[str, str]:
    """Split a ``'provider:model-name'`` string into its two components.

    Args:
        model_name: A model string of the form ``'provider:model-name'``,
            where the provider must be one of the supported providers.

    Returns:
        A ``(provider, model_name)`` tuple.

    .. code-block:: python

        _parse_model_string("openai:text-embedding-3-small")
        # Returns: ("openai", "text-embedding-3-small")

        _parse_model_string("bedrock:amazon.titan-embed-text-v1")
        # Returns: ("bedrock", "amazon.titan-embed-text-v1")

    Raises:
        ValueError: If the string lacks a ``:`` separator, names an
            unsupported provider, or has an empty model part.
    """
    # Guard: reject strings that cannot be split at all.
    if ":" not in model_name:
        raise ValueError(
            f"Invalid model format '{model_name}'.\n"
            f"Model name must be in format 'provider:model-name'\n"
            f"Example valid model strings:\n"
            f"  - openai:text-embedding-3-small\n"
            f"  - bedrock:amazon.titan-embed-text-v1\n"
            f"  - cohere:embed-english-v3.0\n"
            f"Supported providers: {_SUPPORTED_PROVIDERS}"
        )

    # Only split on the FIRST colon: model names may themselves contain colons.
    prefix, _, remainder = model_name.partition(":")
    prefix = prefix.strip().lower()
    remainder = remainder.strip()

    if prefix not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{prefix}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    if not remainder:
        raise ValueError("Model name cannot be empty")
    return prefix, remainder


def _infer_model_and_provider(
    model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
    """Resolve a ``(provider, model_name)`` pair from a model string.

    If ``provider`` is not given explicitly, it is parsed from a
    ``'provider:model-name'`` string via :func:`_parse_model_string`.
    When ``provider`` IS given explicitly, ``model`` is used verbatim
    (even if it contains a colon).

    Args:
        model: Model name, optionally prefixed with ``'provider:'``.
        provider: Optional explicit provider key overriding string parsing.

    Returns:
        A ``(provider, model_name)`` tuple.

    Raises:
        ValueError: If the model name is empty, no provider can be
            determined, or the provider is unsupported.
    """
    if not model.strip():
        raise ValueError("Model name cannot be empty")

    if provider is None and ":" in model:
        provider, model_name = _parse_model_string(model)
    else:
        # Explicit provider (or no separator): take the model string as-is.
        model_name = model

    if not provider:
        raise ValueError(
            "Must specify either:\n"
            "1. A model string in format 'provider:model-name'\n"
            "   Example: 'openai:text-embedding-3-small'\n"
            "2. Or explicitly set provider from: "
            f"{_SUPPORTED_PROVIDERS}"
        )

    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    return provider, model_name


@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
    """Raise an informative ``ImportError`` if *pkg* is not importable.

    Results are cached (one slot per supported provider) so repeated
    initializations only probe each provider package once.

    Args:
        pkg: Importable module name of the integration package,
            e.g. ``"langchain_openai"``.

    Raises:
        ImportError: If the package cannot be found.
    """
    if not util.find_spec(pkg):
        # Module names use underscores, but the installable PyPI
        # distributions are hyphenated (e.g. `langchain-openai`), so
        # kebab-case the name in the suggested pip command.
        pkg_kebab = pkg.replace("_", "-")
        raise ImportError(
            f"Could not import {pkg} python package. "
            f"Please install it with `pip install {pkg_kebab}`"
        )


@beta()
def init_embeddings(
    model: str,
    *,
    provider: Optional[str] = None,
    **kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
    """Initialize an embeddings model from a model name and optional provider.

    **Note:** Must have the integration package corresponding to the model
    provider installed.

    Args:
        model: Name of the model to use. Can be either:

            - A model string like "openai:text-embedding-3-small"
            - Just the model name if provider is specified
        provider: Optional explicit provider name. If not specified,
            will attempt to parse from the model string. Supported providers
            and their required packages:

            {_get_provider_list()}

        **kwargs: Additional model-specific parameters passed to the embedding
            model. These vary by provider, see the provider-specific
            documentation for details.

    Returns:
        An Embeddings instance that can generate embeddings for text.

    Raises:
        ValueError: If the model provider is not supported or cannot be
            determined
        ImportError: If the required provider package is not installed

    .. dropdown:: Example Usage
        :open:

        .. code-block:: python

            # Using a model string
            model = init_embeddings("openai:text-embedding-3-small")
            model.embed_query("Hello, world!")

            # Using explicit provider
            model = init_embeddings(
                model="text-embedding-3-small",
                provider="openai"
            )
            model.embed_documents(["Hello, world!", "Goodbye, world!"])

            # With additional parameters
            model = init_embeddings(
                "openai:text-embedding-3-small",
                api_key="sk-..."
            )

    .. versionadded:: 0.3.9
    """
    if not model:
        providers = _SUPPORTED_PROVIDERS.keys()
        raise ValueError(
            "Must specify model name. "
            f"Supported providers are: {', '.join(providers)}"
        )

    provider, model_name = _infer_model_and_provider(model, provider=provider)
    pkg = _SUPPORTED_PROVIDERS[provider]
    _check_pkg(pkg)

    # Imports are deferred to the matched branch so only the selected
    # provider's integration package needs to be installed.
    if provider == "openai":
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(model=model_name, **kwargs)
    elif provider == "azure_openai":
        from langchain_openai import AzureOpenAIEmbeddings

        return AzureOpenAIEmbeddings(model=model_name, **kwargs)
    elif provider == "google_vertexai":
        from langchain_google_vertexai import VertexAIEmbeddings

        return VertexAIEmbeddings(model=model_name, **kwargs)
    elif provider == "bedrock":
        from langchain_aws import BedrockEmbeddings

        # Bedrock's constructor takes `model_id`, not `model`.
        return BedrockEmbeddings(model_id=model_name, **kwargs)
    elif provider == "cohere":
        from langchain_cohere import CohereEmbeddings

        return CohereEmbeddings(model=model_name, **kwargs)
    elif provider == "mistralai":
        from langchain_mistralai import MistralAIEmbeddings

        return MistralAIEmbeddings(model=model_name, **kwargs)
    elif provider == "huggingface":
        from langchain_huggingface import HuggingFaceEmbeddings

        # HuggingFace's constructor takes `model_name`, not `model`.
        return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
    else:
        # Defensive: unreachable after _infer_model_and_provider's validation,
        # kept in case the provider table and this dispatch drift apart.
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
# Public API of this module. ``Embeddings`` is re-exported purely for
# backwards compatibility with older import paths.
__all__ = ["init_embeddings", "Embeddings"]