Source code for langchain_azure_ai.embeddings.inference

"""Azure AI embeddings model inference API."""

import logging
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    Mapping,
    Optional,
    Union,
)

from azure.ai.inference import EmbeddingsClient
from azure.ai.inference.aio import EmbeddingsClient as EmbeddingsClientAsync
from azure.ai.inference.models import EmbeddingInputType
from azure.core.credentials import AzureKeyCredential, TokenCredential
from azure.core.exceptions import HttpResponseError
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator

from langchain_azure_ai.utils.utils import get_endpoint_from_project

logger = logging.getLogger(__name__)


[docs] class AzureAIEmbeddingsModel(BaseModel, Embeddings): """Azure AI model inference for embeddings. Examples: .. code-block:: python from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel embed_model = AzureAIEmbeddingsModel( endpoint="https://[your-endpoint].inference.ai.azure.com", credential="your-api-key", ) If your endpoint supports multiple models, indicate the parameter `model_name`: .. code-block:: python from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel embed_model = AzureAIEmbeddingsModel( endpoint="https://[your-service].services.ai.azure.com/models", credential="your-api-key", model_name="cohere-embed-v3-multilingual" ) Troubleshooting: To diagnostic issues with the model, you can enable debug logging: .. code-block:: python import sys import logging from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel logger = logging.getLogger("azure") # Set the desired logging level. logger.setLevel(logging.DEBUG) handler = logging.StreamHandler(stream=sys.stdout) logger.addHandler(handler) model = AzureAIEmbeddingsModel( endpoint="https://[your-service].services.ai.azure.com/models", credential="your-api-key", model_name="cohere-embed-v3-multilingual", client_kwargs={ "logging_enable": True } ) """ model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) project_connection_string: Optional[str] = None """The connection string to use for the Azure AI project. If this is specified, then the `endpoint` parameter becomes optional and `credential` has to be of type `TokenCredential`.""" endpoint: Optional[str] = None """The endpoint URI where the model is deployed. Either this or the `project_connection_string` parameter must be specified.""" credential: Union[str, AzureKeyCredential, TokenCredential] """The API key or credential to use for the Azure AI model inference.""" api_version: Optional[str] = None """The API version to use for the Azure AI model inference API. If None, the default version is used.""" model_name: Optional[str] = None """The name of the model to use for inference, if the endpoint is running more than one model. If not, this parameter is ignored.""" embed_batch_size: int = 1024 """The batch size for embedding requests. The default is 1024.""" dimensions: Optional[int] = None """The number of dimensions in the embeddings to generate. If None, the model's default is used.""" model_kwargs: Dict[str, Any] = {} """Additional kwargs model parameters.""" client_kwargs: Dict[str, Any] = {} """Additional kwargs for the Azure AI client used.""" _client: EmbeddingsClient = PrivateAttr() _async_client: EmbeddingsClientAsync = PrivateAttr() _embed_input_type: Optional[EmbeddingInputType] = PrivateAttr() _model_name: Optional[str] = PrivateAttr()
[docs] @pre_init def validate_environment(cls, values: Dict) -> Any: """Validate that api key exists in environment.""" values["endpoint"] = get_from_dict_or_env( values, "endpoint", "AZURE_INFERENCE_ENDPOINT" ) values["credential"] = get_from_dict_or_env( values, "credential", "AZURE_INFERENCE_CREDENTIAL" ) if values["api_version"]: values["client_kwargs"]["api_version"] = values["api_version"] return values
@model_validator(mode="after") def initialize_client(self) -> "AzureAIEmbeddingsModel": """Initialize the Azure AI model inference client.""" if self.project_connection_string: if not isinstance(self.credential, TokenCredential): raise ValueError( "When using the `project_connection_string` parameter, the " "`credential` parameter must be of type `TokenCredential`." ) self.endpoint, self.credential = get_endpoint_from_project( self.project_connection_string, self.credential ) credential = ( AzureKeyCredential(self.credential) if isinstance(self.credential, str) else self.credential ) self._client = EmbeddingsClient( endpoint=self.endpoint, # type: ignore[arg-type] credential=credential, # type: ignore[arg-type] model=self.model_name, user_agent="langchain-azure-ai", **self.client_kwargs, ) self._async_client = EmbeddingsClientAsync( endpoint=self.endpoint, # type: ignore[arg-type] credential=credential, # type: ignore[arg-type] model=self.model_name, user_agent="langchain-azure-ai", **self.client_kwargs, ) if not self.model_name: try: # Get model info from the endpoint. This method may not be supported # by all endpoints. model_info = self._client.get_model_info() self._model_name = model_info.get("model_name", None) self._embed_input_type = ( None if model_info.get("model_provider_name", None).lower() == "cohere" else EmbeddingInputType.TEXT ) except HttpResponseError: logger.warning( f"Endpoint '{self.endpoint}' does not support model metadata " "retrieval. Unable to populate model attributes." ) self._model_name = "" self._embed_input_type = EmbeddingInputType.TEXT else: self._embed_input_type = ( None if "cohere" in self.model_name.lower() else EmbeddingInputType.TEXT ) return self def _get_model_params(self, **kwargs: Dict[str, Any]) -> Mapping[str, Any]: params: Dict[str, Any] = {} if self.dimensions: params["dimensions"] = self.dimensions if self.model_kwargs: params["model_extras"] = self.model_kwargs params.update(kwargs) return params def _embed( self, texts: list[str], input_type: EmbeddingInputType ) -> Generator[list[float], None, None]: for text_batch in range(0, len(texts), self.embed_batch_size): response = self._client.embed( input=texts[text_batch : text_batch + self.embed_batch_size], input_type=self._embed_input_type or input_type, **self._get_model_params(), ) for data in response.data: yield data.embedding # type: ignore async def _embed_async( self, texts: list[str], input_type: EmbeddingInputType ) -> AsyncGenerator[list[float], None]: for text_batch in range(0, len(texts), self.embed_batch_size): response = await self._async_client.embed( input=texts[text_batch : text_batch + self.embed_batch_size], input_type=self._embed_input_type or input_type, **self._get_model_params(), ) for data in response.data: yield data.embedding # type: ignore
[docs] def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs. Args: texts: List of text to embed. Returns: List of embeddings. """ return list(self._embed(texts, EmbeddingInputType.DOCUMENT))
[docs] def embed_query(self, text: str) -> list[float]: """Embed query text. Args: text: Text to embed. Returns: Embedding. """ return list(self._embed([text], EmbeddingInputType.QUERY))[0]
[docs] async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs. Args: texts: List of text to embed. Returns: List of embeddings. """ return self._embed_async(texts, EmbeddingInputType.DOCUMENT) # type: ignore[return-value]
[docs] async def aembed_query(self, text: str) -> list[float]: """Asynchronous Embed query text. Args: text: Text to embed. Returns: Embedding. """ async for item in self._embed_async([text], EmbeddingInputType.QUERY): return item return []