Source code for langchain_aws.utilities.redis
from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
import numpy as np
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
    from redis.client import Redis as RedisType  # type: ignore[import-untyped]
def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
    """Serialize a list of numbers into a raw byte string.

    Args:
        array: Values to serialize.
        dtype: NumPy dtype the values are cast to before serialization.

    Returns:
        The raw bytes of the values after casting them to ``dtype``.
    """
    as_ndarray = np.array(array)
    typed = as_ndarray.astype(dtype)
    return typed.tobytes()
def _buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]:
    """Decode a raw byte string back into a list of numbers.

    Args:
        buffer: Bytes to interpret.
        dtype: NumPy dtype used to interpret the bytes.

    Returns:
        The decoded values as a plain Python list.
    """
    decoded = np.frombuffer(buffer, dtype=dtype)
    return decoded.tolist()
[docs]
class TokenEscaper:
    """
    Escape punctuation within an input string.

    Every substring matched by ``escaped_chars_re`` is prefixed with a single
    backslash; all other text is returned unchanged.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        """Create an escaper.

        Args:
            escape_chars_re: Compiled pattern matching the text to escape.
                When omitted, a pattern compiled from
                ``DEFAULT_ESCAPED_CHARS`` is used.
        """
        if escape_chars_re is not None:
            self.escaped_chars_re = escape_chars_re
        else:
            self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

    def escape(self, value: str) -> str:
        """Return ``value`` with every match of the pattern backslash-escaped.

        Args:
            value: Text to escape.

        Returns:
            A copy of ``value`` in which each match is preceded by ``\\``.

        Raises:
            TypeError: If ``value`` is not a ``str``.
        """
        if not isinstance(value, str):
            raise TypeError(
                "Value must be a string object for token escaping. "
                f"Got type {type(value)}"
            )
        # In the replacement template, ``\\`` emits a literal backslash and
        # ``\g<0>`` re-inserts the matched text.
        return self.escaped_chars_re.sub(r"\\\g<0>", value)
 
[docs]
def get_client(redis_url: str, **kwargs: Any) -> RedisType:
    """Get a redis client from the connection url given.

    The url is handed to ``redis.from_url``. If the server then reports that
    cluster mode is enabled, that first connection is closed and a cluster
    client created from the same url and keyword arguments is returned
    instead.

    To use, you should have the ``redis`` python package installed.

    Args:
        redis_url: Connection url, in any form ``redis.from_url`` accepts.
        **kwargs: Extra keyword arguments forwarded unchanged to the client
            constructor(s).

    Returns:
        A connected client: either the one built by ``redis.from_url`` or,
        for cluster-enabled servers, a cluster client.

    Raises:
        ImportError: If the ``redis`` package cannot be imported.

    Example:
        .. code-block:: python

            from langchain_aws.utilities.redis import get_client

            redis_client = get_client(
                redis_url="redis://username:password@localhost:6379",
            )
    """
    # Import lazily so that the dependency is only required when a client is
    # actually requested.
    try:
        import redis  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "Could not import redis python package. "
            "Please install it with `pip install redis>=4.1.0`."
        ) from e
    # Connect to redis server from url, reconnect with cluster client if needed
    redis_client = redis.from_url(redis_url, **kwargs)
    if _check_for_cluster(redis_client):
        # The standalone client is no longer needed; release its connections
        # before switching to the cluster-aware client.
        redis_client.close()
        redis_client = _redis_cluster_client(redis_url, **kwargs)
    return redis_client
def _check_for_cluster(redis_client: RedisType) -> bool:
    """Report whether the connected server has cluster mode enabled.

    Queries the ``cluster`` section of the server's INFO output and checks
    its ``cluster_enabled`` field. Any ``RedisError`` raised while doing so
    is treated as "not a cluster".
    """
    from redis.exceptions import RedisError

    try:
        return redis_client.info("cluster")["cluster_enabled"] == 1
    except RedisError:
        return False
def _redis_cluster_client(redis_url: str, **kwargs: Any) -> RedisType:
    """Build a ``RedisCluster`` client from a url.

    Args:
        redis_url: Connection url passed to ``RedisCluster.from_url``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The client created by ``RedisCluster.from_url``.
    """
    from redis.cluster import RedisCluster  # type: ignore[import-untyped]

    cluster_client = RedisCluster.from_url(redis_url, **kwargs)
    return cluster_client  # type: ignore[return-value]