Source code for langchain_azure_ai.embeddings.inference

"""Azure AI embeddings model inference API."""

import logging
from typing import (
    Any,
    Dict,
    Mapping,
    Optional,
)

from azure.ai.inference import EmbeddingsClient
from azure.ai.inference.aio import EmbeddingsClient as EmbeddingsClientAsync
from azure.ai.inference.models import EmbeddingInputType
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from langchain_core.embeddings import Embeddings
from pydantic import Field, PrivateAttr, model_validator

from langchain_azure_ai._resources import ModelInferenceService

logger = logging.getLogger(__name__)


class AzureAIEmbeddingsModel(ModelInferenceService, Embeddings):
    """Azure AI model inference for embeddings.

    Examples:
        .. code-block:: python

            from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel

            embed_model = AzureAIEmbeddingsModel(
                endpoint="https://[your-endpoint].inference.ai.azure.com",
                credential="your-api-key",
            )

        If your endpoint supports multiple models, indicate the model to use
        with the `model` parameter:

        .. code-block:: python

            from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel

            embed_model = AzureAIEmbeddingsModel(
                endpoint="https://[your-service].services.ai.azure.com/models",
                credential="your-api-key",
                model="cohere-embed-v3-multilingual"
            )

    Troubleshooting:
        To diagnose issues with the model, you can enable debug logging:

        .. code-block:: python

            import sys
            import logging
            from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel

            logger = logging.getLogger("azure")

            # Set the desired logging level.
            logger.setLevel(logging.DEBUG)

            handler = logging.StreamHandler(stream=sys.stdout)
            logger.addHandler(handler)

            model = AzureAIEmbeddingsModel(
                endpoint="https://[your-service].services.ai.azure.com/models",
                credential="your-api-key",
                model="cohere-embed-v3-multilingual",
                client_kwargs={ "logging_enable": True }
            )
    """

    model_name: Optional[str] = Field(default=None, alias="model")
    """The name of the model to use for inference, if the endpoint is running
    more than one model. If not, this parameter is ignored."""

    embed_batch_size: int = 1024
    """The batch size for embedding requests. The default is 1024."""

    dimensions: Optional[int] = None
    """The number of dimensions in the embeddings to generate.
    If None, the model's default is used."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Additional kwargs model parameters."""

    _client: EmbeddingsClient = PrivateAttr()
    _async_client: EmbeddingsClientAsync = PrivateAttr()
    _embed_input_type: Optional[EmbeddingInputType] = PrivateAttr()
    _model_name: Optional[str] = PrivateAttr()

    @model_validator(mode="after")
    def initialize_client(self) -> "AzureAIEmbeddingsModel":
        """Initialize the Azure AI model inference clients.

        Builds the synchronous and asynchronous ``EmbeddingsClient`` from
        ``endpoint``, ``credential`` and ``client_kwargs`` (attributes that are
        not declared in this class; presumably provided by
        ``ModelInferenceService``), then decides which input type to send
        with each request.

        Returns:
            The validated model instance.
        """
        # A plain string is treated as an API key; anything else is passed
        # through to the client unchanged.
        credential = (
            AzureKeyCredential(self.credential)
            if isinstance(self.credential, str)
            else self.credential
        )

        self._client = EmbeddingsClient(
            endpoint=self.endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
            model=self.model_name,
            **self.client_kwargs,
        )
        self._async_client = EmbeddingsClientAsync(
            endpoint=self.endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
            model=self.model_name,
            **self.client_kwargs,
        )

        if self.model_name:
            # Keep the private attribute populated on every path so that
            # reading it never raises AttributeError.
            self._model_name = self.model_name
            self._embed_input_type = (
                None if "cohere" in self.model_name.lower() else EmbeddingInputType.TEXT
            )
            return self

        try:
            # Get model info from the endpoint. This method may not be supported
            # by all endpoints.
            model_info = self._client.get_model_info()
        except HttpResponseError:
            logger.warning(
                "Endpoint '%s' does not support model metadata "
                "retrieval. Unable to populate model attributes.",
                self.endpoint,
            )
            self._model_name = ""
            self._embed_input_type = EmbeddingInputType.TEXT
            return self

        self._model_name = model_info.get("model_name", None)
        # The provider name may be missing from the metadata; treat that the
        # same as "not Cohere" instead of failing on ``None.lower()``.
        provider_name = model_info.get("model_provider_name", None) or ""
        self._embed_input_type = (
            None if provider_name.lower() == "cohere" else EmbeddingInputType.TEXT
        )
        return self

    def _get_model_params(self, **kwargs: Any) -> Mapping[str, Any]:
        """Assemble the optional keyword arguments sent with each request.

        Args:
            **kwargs: Extra parameters; these override the configured ones.

        Returns:
            A mapping containing ``dimensions`` and ``model_extras`` when they
            are set, updated with ``kwargs``.
        """
        params: Dict[str, Any] = {}
        if self.dimensions:
            params["dimensions"] = self.dimensions
        if self.model_kwargs:
            params["model_extras"] = self.model_kwargs
        params.update(kwargs)
        return params

    def _iter_batches(self, texts: list[str]) -> Any:
        """Yield consecutive slices of ``texts`` of at most ``embed_batch_size``.

        Raises:
            ValueError: If ``embed_batch_size`` is not a positive integer. A
                negative size would otherwise produce no batches at all and
                silently return no embeddings.
        """
        batch_size = self.embed_batch_size
        if batch_size <= 0:
            raise ValueError(
                f"embed_batch_size must be a positive integer, got {batch_size}."
            )
        for start in range(0, len(texts), batch_size):
            yield texts[start : start + batch_size]

    def _embed(
        self, texts: list[str], input_type: EmbeddingInputType
    ) -> list[list[float]]:
        """Embed ``texts`` in batches using the synchronous client.

        Args:
            texts: Texts to embed.
            input_type: Input type used only when no fixed input type was
                selected during initialization.

        Returns:
            One embedding per item returned by the service, in request order.
        """
        embeddings = []
        for batch in self._iter_batches(texts):
            response = self._client.embed(
                input=batch,
                # A fixed input type chosen at initialization wins over the
                # caller-supplied one.
                input_type=self._embed_input_type or input_type,
                **self._get_model_params(),
            )
            embeddings.extend([data.embedding for data in response.data])
        return embeddings  # type: ignore[return-value]

    async def _embed_async(
        self, texts: list[str], input_type: EmbeddingInputType
    ) -> list[list[float]]:
        """Embed ``texts`` in batches using the asynchronous client.

        Args:
            texts: Texts to embed.
            input_type: Input type used only when no fixed input type was
                selected during initialization.

        Returns:
            One embedding per item returned by the service, in request order.
        """
        embeddings = []
        for batch in self._iter_batches(texts):
            response = await self._async_client.embed(
                input=batch,
                input_type=self._embed_input_type or input_type,
                **self._get_model_params(),
            )
            embeddings.extend([data.embedding for data in response.data])
        return embeddings  # type: ignore[return-value]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings.
        """
        return self._embed(texts, EmbeddingInputType.DOCUMENT)

    def embed_query(self, text: str) -> list[float]:
        """Embed query text.

        Args:
            text: Text to embed.

        Returns:
            Embedding, or an empty list if the service returned no embedding
            (matching ``aembed_query``).
        """
        embeddings = self._embed([text], EmbeddingInputType.QUERY)
        return embeddings[0] if embeddings else []

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings.
        """
        return await self._embed_async(texts, EmbeddingInputType.DOCUMENT)

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronous Embed query text.

        Args:
            text: Text to embed.

        Returns:
            Embedding, or an empty list if the service returned no embedding.
        """
        embeddings = await self._embed_async([text], EmbeddingInputType.QUERY)
        return embeddings[0] if embeddings else []