Source code for redisvl.utils.vectorize.base

import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import Annotated

from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType

logger = logging.getLogger(__name__)


class Vectorizers(Enum):
    azure_openai = "azure_openai"
    openai = "openai"
    cohere = "cohere"
    mistral = "mistral"
    vertexai = "vertexai"
    hf = "hf"
    voyageai = "voyageai"


class BaseVectorizer(BaseModel):
    """Base RedisVL vectorizer interface.

    This class defines the interface for text vectorization with an optional
    caching layer to improve performance by avoiding redundant API calls.

    Attributes:
        model: The name of the embedding model.
        dtype: The data type of the embeddings, defaults to "float32".
        dims: The dimensionality of the vectors.
        cache: Optional embedding cache to store and retrieve embeddings.
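
    Examples:
        A minimal sketch of a custom subclass (the ``SimpleVectorizer`` name and
        its fixed-length output are illustrative placeholders, not part of the
        library):

        >>> class SimpleVectorizer(BaseVectorizer):
        ...     def _embed(self, text, **kwargs):
        ...         # Replace with a real model or API call
        ...         return [float(len(text))] * (self.dims or 4)
        ...     def _embed_many(self, texts, batch_size=10, **kwargs):
        ...         return [self._embed(t, **kwargs) for t in texts]
        >>> vectorizer = SimpleVectorizer(model="simple", dims=4)
        >>> vectorizer.embed("Hello world")
        [11.0, 11.0, 11.0, 11.0]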
    """

    model: str
    dtype: str = "float32"
    dims: Annotated[Optional[int], Field(strict=True, gt=0)] = None
    cache: Optional[EmbeddingsCache] = Field(default=None)
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def type(self) -> str:
        """Return the type of vectorizer."""
        return "base"

    @field_validator("dtype")
    @classmethod
    def check_dtype(cls, dtype):
        """Validate the data type is supported."""
        try:
            VectorDataType(dtype.upper())
        except ValueError:
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
            )
        return dtype

    def embed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        skip_cache: bool = False,
        **kwargs,
    ) -> Union[List[float], bytes]:
        """Generate a vector embedding for a text string.

        Args:
            text: The text to convert to a vector embedding
            preprocess: Function to apply to the text before embedding
            as_buffer: Return the embedding as a binary buffer instead of a list
            skip_cache: Bypass the cache for this request
            **kwargs: Additional model-specific parameters

        Returns:
            The vector embedding as either a list of floats or binary buffer

        Examples:
            >>> embedding = vectorizer.embed("Hello world")
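            >>> buffer = vectorizer.embed("Hello world", as_buffer=True)  # bytes, ready for Redis storage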
        """
        # Apply preprocessing if provided
        if preprocess is not None:
            text = preprocess(text)

        # Check cache if available and not skipped
        if self.cache is not None and not skip_cache:
            try:
                cache_result = self.cache.get(text=text, model_name=self.model)
                if cache_result:
                    logger.debug(f"Cache hit for text with model {self.model}")
                    return self._process_embedding(
                        cache_result["embedding"], as_buffer, self.dtype
                    )
            except Exception as e:
                logger.warning(f"Error accessing embedding cache: {str(e)}")

        # Extract optional cache metadata, then generate the embedding using
        # the provider-specific implementation
        cache_metadata = kwargs.pop("metadata", {})
        embedding = self._embed(text, **kwargs)

        # Store in cache if available and not skipped
        if self.cache is not None and not skip_cache:
            try:
                self.cache.set(
                    text=text,
                    model_name=self.model,
                    embedding=embedding,
                    metadata=cache_metadata,
                )
            except Exception as e:
                logger.warning(f"Error storing in embedding cache: {str(e)}")

        # Process and return result
        return self._process_embedding(embedding, as_buffer, self.dtype)

    def embed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        skip_cache: bool = False,
        **kwargs,
    ) -> Union[List[List[float]], List[bytes]]:
        """Generate vector embeddings for multiple texts efficiently.

        Args:
            texts: List of texts to convert to vector embeddings
            preprocess: Function to apply to each text before embedding
            batch_size: Number of texts to process in each API call
            as_buffer: Return embeddings as binary buffers instead of lists
            skip_cache: Bypass the cache for this request
            **kwargs: Additional model-specific parameters

        Returns:
            List of vector embeddings in the same order as the input texts

        Examples:
            >>> embeddings = vectorizer.embed_many(["Hello", "World"], batch_size=2)
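            >>> embeddings = vectorizer.embed_many(["Hello", "World"], preprocess=str.lower)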
        """
        if not texts:
            return []

        # Apply preprocessing if provided
        if preprocess is not None:
            processed_texts = [preprocess(text) for text in texts]
        else:
            processed_texts = texts

        # Get cached embeddings and identify misses
        results, cache_misses, cache_miss_indices = self._get_from_cache_batch(
            processed_texts, skip_cache
        )

        # Generate embeddings for cache misses
        if cache_misses:
            cache_metadata = kwargs.pop("metadata", {})
            new_embeddings = self._embed_many(
                texts=cache_misses, batch_size=batch_size, **kwargs
            )

            # Store new embeddings in cache
            self._store_in_cache_batch(
                cache_misses, new_embeddings, cache_metadata, skip_cache
            )

            # Insert new embeddings into results array
            for idx, embedding in zip(cache_miss_indices, new_embeddings):
                results[idx] = embedding

        # Process and return results
        return [self._process_embedding(emb, as_buffer, self.dtype) for emb in results]

    async def aembed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        skip_cache: bool = False,
        **kwargs,
    ) -> Union[List[float], bytes]:
        """Asynchronously generate a vector embedding for a text string.

        Args:
            text: The text to convert to a vector embedding
            preprocess: Function to apply to the text before embedding
            as_buffer: Return the embedding as a binary buffer instead of a list
            skip_cache: Bypass the cache for this request
            **kwargs: Additional model-specific parameters

        Returns:
            The vector embedding as either a list of floats or binary buffer

        Examples:
            >>> embedding = await vectorizer.aembed("Hello world")
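            >>> embedding = await vectorizer.aembed("Hello world", skip_cache=True)  # bypass the cache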
        """
        # Apply preprocessing if provided
        if preprocess is not None:
            text = preprocess(text)

        # Check cache if available and not skipped
        if self.cache is not None and not skip_cache:
            try:
                cache_result = await self.cache.aget(text=text, model_name=self.model)
                if cache_result:
                    logger.debug(f"Async cache hit for text with model {self.model}")
                    return self._process_embedding(
                        cache_result["embedding"], as_buffer, self.dtype
                    )
            except Exception as e:
                logger.warning(
                    f"Error accessing embedding cache asynchronously: {str(e)}"
                )

        # Extract optional cache metadata, then generate the embedding using
        # the provider-specific implementation
        cache_metadata = kwargs.pop("metadata", {})
        embedding = await self._aembed(text, **kwargs)

        # Store in cache if available and not skipped
        if self.cache is not None and not skip_cache:
            try:
                await self.cache.aset(
                    text=text,
                    model_name=self.model,
                    embedding=embedding,
                    metadata=cache_metadata,
                )
            except Exception as e:
                logger.warning(
                    f"Error storing in embedding cache asynchronously: {str(e)}"
                )

        # Process and return result
        return self._process_embedding(embedding, as_buffer, self.dtype)

    async def aembed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        skip_cache: bool = False,
        **kwargs,
    ) -> Union[List[List[float]], List[bytes]]:
        """Asynchronously generate vector embeddings for multiple texts efficiently.

        Args:
            texts: List of texts to convert to vector embeddings
            preprocess: Function to apply to each text before embedding
            batch_size: Number of texts to process in each API call
            as_buffer: Return embeddings as binary buffers instead of lists
            skip_cache: Bypass the cache for this request
            **kwargs: Additional model-specific parameters

        Returns:
            List of vector embeddings in the same order as the input texts

        Examples:
            >>> embeddings = await vectorizer.aembed_many(["Hello", "World"], batch_size=2)
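            >>> buffers = await vectorizer.aembed_many(["Hello", "World"], as_buffer=True)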
        """
        if not texts:
            return []

        # Apply preprocessing if provided
        if preprocess is not None:
            processed_texts = [preprocess(text) for text in texts]
        else:
            processed_texts = texts

        # Get cached embeddings and identify misses
        results, cache_misses, cache_miss_indices = await self._aget_from_cache_batch(
            processed_texts, skip_cache
        )

        # Generate embeddings for cache misses
        if cache_misses:
            cache_metadata = kwargs.pop("metadata", {})
            new_embeddings = await self._aembed_many(
                texts=cache_misses, batch_size=batch_size, **kwargs
            )

            # Store new embeddings in cache
            await self._astore_in_cache_batch(
                cache_misses, new_embeddings, cache_metadata, skip_cache
            )

            # Insert new embeddings into results array
            for idx, embedding in zip(cache_miss_indices, new_embeddings):
                results[idx] = embedding

        # Process and return results
        return [self._process_embedding(emb, as_buffer, self.dtype) for emb in results]

    def _embed(self, text: str, **kwargs) -> List[float]:
        """Generate a vector embedding for a single text."""
        raise NotImplementedError

    def _embed_many(
        self, texts: List[str], batch_size: int = 10, **kwargs
    ) -> List[List[float]]:
        """Generate vector embeddings for a batch of texts."""
        raise NotImplementedError

    async def _aembed(self, text: str, **kwargs) -> List[float]:
        """Asynchronously generate a vector embedding for a single text."""
        logger.warning(
            "This vectorizer has no async embed method. Falling back to sync."
        )
        return self._embed(text, **kwargs)

    async def _aembed_many(
        self, texts: List[str], batch_size: int = 10, **kwargs
    ) -> List[List[float]]:
        """Asynchronously generate vector embeddings for a batch of texts."""
        logger.warning(
            "This vectorizer has no async embed_many method. Falling back to sync."
        )
        return self._embed_many(texts, batch_size=batch_size, **kwargs)

    def _get_from_cache_batch(
        self, texts: List[str], skip_cache: bool
    ) -> tuple[List[Optional[List[float]]], List[str], List[int]]:
        """Get vector embeddings from cache and track cache misses.

        Args:
            texts: List of texts to get from cache
            skip_cache: Whether to skip cache lookup

        Returns:
            Tuple of (results, cache_misses, cache_miss_indices)
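
        Examples:
            For two input texts where only the first is already cached
            (embedding values are illustrative):

            results            -> [[0.1, 0.2], None]
            cache_misses       -> ["second text"]
            cache_miss_indices -> [1]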
        """
        results = [None] * len(texts)
        cache_misses = []
        cache_miss_indices = []

        # Skip cache if requested or no cache available
        if skip_cache or self.cache is None:
            return results, texts, list(range(len(texts)))  # type: ignore

        try:
            # Efficient batch cache lookup
            cache_results = self.cache.mget(texts=texts, model_name=self.model)

            # Process cache hits and collect misses
            for i, (text, cache_result) in enumerate(zip(texts, cache_results)):
                if cache_result:
                    results[i] = cache_result["embedding"]
                else:
                    cache_misses.append(text)
                    cache_miss_indices.append(i)

            logger.debug(
                f"Cache hits: {len(texts) - len(cache_misses)}, misses: {len(cache_misses)}"
            )
        except Exception as e:
            logger.warning(f"Error accessing embedding cache in batch: {str(e)}")
            # On cache error, process all texts
            cache_misses = texts
            cache_miss_indices = list(range(len(texts)))

        return results, cache_misses, cache_miss_indices  # type: ignore

    async def _aget_from_cache_batch(
        self, texts: List[str], skip_cache: bool
    ) -> tuple[List[Optional[List[float]]], List[str], List[int]]:
        """Asynchronously get vector embeddings from cache and track cache misses.

        Args:
            texts: List of texts to get from cache
            skip_cache: Whether to skip cache lookup

        Returns:
            Tuple of (results, cache_misses, cache_miss_indices)
        """
        results = [None] * len(texts)
        cache_misses = []
        cache_miss_indices = []

        # Skip cache if requested or no cache available
        if skip_cache or self.cache is None:
            return results, texts, list(range(len(texts)))  # type: ignore

        try:
            # Efficient batch cache lookup
            cache_results = await self.cache.amget(texts=texts, model_name=self.model)

            # Process cache hits and collect misses
            for i, (text, cache_result) in enumerate(zip(texts, cache_results)):
                if cache_result:
                    results[i] = cache_result["embedding"]
                else:
                    cache_misses.append(text)
                    cache_miss_indices.append(i)

            logger.debug(
                f"Async cache hits: {len(texts) - len(cache_misses)}, misses: {len(cache_misses)}"
            )
        except Exception as e:
            logger.warning(
                f"Error accessing embedding cache in batch asynchronously: {str(e)}"
            )
            # On cache error, process all texts
            cache_misses = texts
            cache_miss_indices = list(range(len(texts)))

        return results, cache_misses, cache_miss_indices  # type: ignore

    def _store_in_cache_batch(
        self,
        texts: List[str],
        embeddings: List[List[float]],
        metadata: Dict,
        skip_cache: bool,
    ) -> None:
        """Store a batch of vector embeddings in the cache.

        Args:
            texts: List of texts that were embedded
            embeddings: List of vector embeddings
            metadata: Metadata to store with the embeddings
            skip_cache: Whether to skip cache storage
        """
        if skip_cache or self.cache is None:
            return

        try:
            # Prepare batch cache storage items
            cache_items = [
                {
                    "text": text,
                    "model_name": self.model,
                    "embedding": emb,
                    "metadata": metadata,
                }
                for text, emb in zip(texts, embeddings)
            ]
            self.cache.mset(items=cache_items)
        except Exception as e:
            logger.warning(f"Error storing batch in embedding cache: {str(e)}")

    async def _astore_in_cache_batch(
        self,
        texts: List[str],
        embeddings: List[List[float]],
        metadata: Dict,
        skip_cache: bool,
    ) -> None:
        """Asynchronously store a batch of vector embeddings in the cache.

        Args:
            texts: List of texts that were embedded
            embeddings: List of vector embeddings
            metadata: Metadata to store with the embeddings
            skip_cache: Whether to skip cache storage
        """
        if skip_cache or self.cache is None:
            return

        try:
            # Prepare batch cache storage items
            cache_items = [
                {
                    "text": text,
                    "model_name": self.model,
                    "embedding": emb,
                    "metadata": metadata,
                }
                for text, emb in zip(texts, embeddings)
            ]
            await self.cache.amset(items=cache_items)
        except Exception as e:
            logger.warning(
                f"Error storing batch in embedding cache asynchronously: {str(e)}"
            )

    def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
        """Split a sequence into batches of specified size.

        Args:
            seq: Sequence to split into batches
            size: Batch size
            preprocess: Optional function to preprocess each item

        Yields:
            Batches of the sequence
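
        Examples:
            Assuming ``vectorizer`` is any ``BaseVectorizer`` instance:

            >>> list(vectorizer.batchify([1, 2, 3, 4, 5], size=2))
            [[1, 2], [3, 4], [5]]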
        """
        for pos in range(0, len(seq), size):
            if preprocess is not None:
                yield [preprocess(chunk) for chunk in seq[pos : pos + size]]
            else:
                yield seq[pos : pos + size]

    def _process_embedding(
        self, embedding: Optional[List[float]], as_buffer: bool, dtype: str
    ) -> Union[List[float], bytes, None]:
        """Convert a vector embedding to the requested output format.

        Returns a binary buffer when ``as_buffer`` is True; otherwise returns
        the embedding unchanged (``None`` is passed through).
        """
        if embedding is not None and as_buffer:
            return array_to_buffer(embedding, dtype)
        return embedding