import base64
import hashlib
import logging
from datetime import datetime
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from elasticsearch import (
Elasticsearch,
exceptions,
helpers,
)
from elasticsearch.helpers import BulkIndexError
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.load import dumps, loads
from langchain_core.stores import ByteStore
from langchain_elasticsearch.client import create_elasticsearch_client
# Type-checker-only import. Note that `Elasticsearch` is also imported
# unconditionally from `elasticsearch` at the top of the module, so this guard
# has no runtime effect.
if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

# Module-level logger used by the index helper and cache classes in this module.
logger = logging.getLogger(__name__)
def _manage_cache_index(
    es_client: Elasticsearch, index_name: str, mapping: Dict[str, Any]
) -> bool:
    """Make sure the cache's index or alias is ready to receive documents.

    Three cases are handled:

    * ``index_name`` is an alias: the ``"mappings"`` section of ``mapping`` is
      applied to it with ``put_mapping``.
    * Nothing named ``index_name`` exists: a new index is created using the
      full ``mapping`` body.
    * ``index_name`` is an existing concrete index: it is left untouched.

    Args:
        es_client: Client used to inspect and modify indices.
        index_name: Name of the index or alias that backs the cache.
        mapping: Full index body; must contain a ``"mappings"`` key.

    Returns:
        ``True`` if ``index_name`` is an alias, ``False`` otherwise.
    """
    indices = es_client.indices

    if indices.exists_alias(name=index_name):
        indices.put_mapping(index=index_name, body=mapping["mappings"])
        return True

    if not indices.exists(index=index_name):
        logger.debug(f"Creating new Elasticsearch index: {index_name}")
        indices.create(index=index_name, body=mapping)

    return False
class ElasticsearchCache(BaseCache):
    """An Elasticsearch cache integration for LLMs.

    Each cached generation is stored as one document. Its ``_id`` is derived
    from the prompt and the LLM parameter string (see ``_key``). When the
    configured name is an alias, several indices behind it may hold a document
    with the same id. Lookups then return the most recently written one,
    according to the document's ``timestamp`` field.
    """

    def __init__(
        self,
        index_name: str,
        store_input: bool = True,
        store_input_params: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        *,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        es_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Elasticsearch cache store.

        The caller specifies the index/alias to use and which additional
        information (input, input parameters, other metadata) is stored in
        the cache.

        Args:
            index_name (str): The name of the index or the alias to use for
                the cache. If the name is an alias, the default mapping (see
                the `mapping` property) is applied to it. If no index or alias
                of that name exists, an index is created with that mapping.
            store_input (bool): Whether to store the LLM input in the cache,
                i.e., the input prompt. Defaults to True.
            store_input_params (bool): Whether to store the input parameters in
                the cache, i.e., the LLM parameters used to generate the LLM
                response. Defaults to True.
            metadata (Optional[dict]): Additional metadata to store in the
                cache, for filtering purposes. This must be JSON serializable
                in an Elasticsearch document. Defaults to None.
            es_url: URL of the Elasticsearch instance to connect to.
            es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
            es_user: Username to use when connecting to Elasticsearch.
            es_password: Password to use when connecting to Elasticsearch.
            es_api_key: API key to use when connecting to Elasticsearch.
            es_params: Other parameters for the Elasticsearch client.
        """
        self._index_name = index_name
        self._store_input = store_input
        self._store_input_params = store_input_params
        self._metadata = metadata
        self._es_client = create_elasticsearch_client(
            url=es_url,
            cloud_id=es_cloud_id,
            api_key=es_api_key,
            username=es_user,
            password=es_password,
            params=es_params,
        )
        # True when `index_name` is an alias. This switches lookups to a
        # search and index writes to `require_alias`.
        self._is_alias = _manage_cache_index(
            self._es_client,
            self._index_name,
            self.mapping,
        )

    @cached_property
    def mapping(self) -> Dict[str, Any]:
        """Get the default mapping for the index.

        Input, params and output are stored as non-indexed text. Only
        `metadata` and `timestamp` are searchable.
        """
        return {
            "mappings": {
                "properties": {
                    "llm_output": {"type": "text", "index": False},
                    "llm_params": {"type": "text", "index": False},
                    "llm_input": {"type": "text", "index": False},
                    "metadata": {"type": "object"},
                    "timestamp": {"type": "date"},
                }
            }
        }

    @staticmethod
    def _key(prompt: str, llm_string: str) -> str:
        """Generate a key for the cache store: MD5 hex digest of prompt + llm_string."""
        # NOTE: plain concatenation means two (prompt, llm_string) pairs with
        # the same concatenation share a key. This is kept unchanged so that
        # existing cache entries stay addressable.
        return hashlib.md5((prompt + llm_string).encode()).hexdigest()

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up the cached generations for `prompt` and `llm_string`.

        Returns:
            The deserialized list of generations, or None on a cache miss.
        """
        cache_key = self._key(prompt, llm_string)
        if self._is_alias:
            # Several indices behind the alias may hold a document with this
            # id. Sort newest-first so the most recently written record wins.
            # (Sorting "asc" would return the oldest one.)
            result = self._es_client.search(
                index=self._index_name,
                body={
                    "query": {"term": {"_id": cache_key}},
                    "sort": {"timestamp": {"order": "desc"}},
                    "size": 1,
                },
                source_includes=["llm_output"],
            )
            hits = result["hits"]["hits"]
            if not hits:
                return None
            record = hits[0]
        else:
            try:
                record = self._es_client.get(
                    index=self._index_name, id=cache_key, source=["llm_output"]
                )
            except exceptions.NotFoundError:
                return None
        return [loads(item) for item in record["_source"]["llm_output"]]

    def build_document(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> Dict[str, Any]:
        """Build the Elasticsearch document for storing a single LLM interaction.

        `llm_params`, `metadata` and `llm_input` are only included when the
        corresponding constructor options enable them.
        """
        body: Dict[str, Any] = {
            "llm_output": [dumps(item) for item in return_val],
            "timestamp": datetime.now().isoformat(),
        }
        if self._store_input_params:
            body["llm_params"] = llm_string
        if self._metadata is not None:
            body["metadata"] = self._metadata
        if self._store_input:
            body["llm_input"] = prompt
        return body

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Write the document for this prompt/LLM pair under its cache key.

        `require_alias` is set when the target is an alias. `refresh=True`
        asks Elasticsearch to refresh so the write is visible to the next
        lookup.
        """
        body = self.build_document(prompt, llm_string, return_val)
        self._es_client.index(
            index=self._index_name,
            id=self._key(prompt, llm_string),
            body=body,
            require_alias=self._is_alias,
            refresh=True,
        )

    def clear(self, **kwargs: Any) -> None:
        """Delete every document in the index/alias.

        Extra keyword arguments are accepted for interface compatibility and
        are ignored.
        """
        self._es_client.delete_by_query(
            index=self._index_name,
            body={"query": {"match_all": {}}},
            refresh=True,
            wait_for_completion=True,
        )
class ElasticsearchEmbeddingsCache(ByteStore):
    """An Elasticsearch store for caching embeddings.

    Values (raw vector bytes) are stored base64-encoded in the `vector_dump`
    field of a document. The document's ``_id`` is the MD5 of the optional
    namespace concatenated with the input text (see ``_key``).
    """

    def __init__(
        self,
        index_name: str,
        store_input: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        namespace: Optional[str] = None,
        maximum_duplicates_allowed: int = 1,
        *,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        es_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Elasticsearch embeddings store.

        The caller specifies the index/alias to use and which additional
        information (the input text and any other metadata) is stored with
        each embedding. A namespace can be provided to organize the cache.

        Args:
            index_name (str): The name of the index or the alias to use for
                the cache. If the name is an alias, the default mapping (see
                the `mapping` property) is applied to it. If no index or alias
                of that name exists, an index is created with that mapping.
            store_input (bool): Whether to store the input text in the cache.
                Defaults to True.
            metadata (Optional[dict]): Additional metadata to store in the
                cache, for filtering purposes. This must be JSON serializable
                in an Elasticsearch document. Defaults to None.
            namespace (Optional[str]): A namespace to use for the cache. It is
                prefixed to the input text before hashing the key.
            maximum_duplicates_allowed (int): Defines the maximum number of
                duplicate keys permitted. Must be used in scenarios where the
                same key appears across multiple indices that share the same
                alias. Defaults to 1.
            es_url: URL of the Elasticsearch instance to connect to.
            es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
            es_user: Username to use when connecting to Elasticsearch.
            es_password: Password to use when connecting to Elasticsearch.
            es_api_key: API key to use when connecting to Elasticsearch.
            es_params: Other parameters for the Elasticsearch client.
        """
        self._namespace = namespace
        self._maximum_duplicates_allowed = maximum_duplicates_allowed
        self._index_name = index_name
        self._store_input = store_input
        self._metadata = metadata
        self._es_client = create_elasticsearch_client(
            url=es_url,
            cloud_id=es_cloud_id,
            api_key=es_api_key,
            username=es_user,
            password=es_password,
            params=es_params,
        )
        # True when `index_name` is an alias. This switches `mget` to a search
        # and bulk writes to `require_alias`.
        self._is_alias = _manage_cache_index(
            self._es_client,
            self._index_name,
            self.mapping,
        )

    @staticmethod
    def encode_vector(data: bytes) -> str:
        """Encode raw vector bytes as a base64 (UTF-8) string."""
        return base64.b64encode(data).decode("utf-8")

    @staticmethod
    def decode_vector(data: str) -> bytes:
        """Decode a base64 string back into raw vector bytes."""
        return base64.b64decode(data)

    @cached_property
    def mapping(self) -> Dict[str, Any]:
        """Get the default mapping for the index."""
        return {
            "mappings": {
                "properties": {
                    "text_input": {"type": "text", "index": False},
                    "vector_dump": {
                        "type": "binary",
                        "doc_values": False,
                    },
                    "metadata": {"type": "object"},
                    "timestamp": {"type": "date"},
                    "namespace": {"type": "keyword"},
                }
            }
        }

    def _key(self, input_text: str) -> str:
        """Generate a key for the store: MD5 hex digest of namespace + input text."""
        return hashlib.md5(((self._namespace or "") + input_text).encode()).hexdigest()

    @classmethod
    def _deduplicate_hits(cls, hits: List[dict]) -> Dict[str, bytes]:
        """
        Collapse the results from a search query with multiple indices,
        returning only the latest version of each document.

        Hits are ordered newest-first by their ISO-formatted `timestamp`. The
        first occurrence of each ``_id`` is kept.
        """
        map_ids = {}
        for hit in sorted(
            hits,
            key=lambda x: datetime.fromisoformat(x["_source"]["timestamp"]),
            reverse=True,
        ):
            vector_id: str = hit["_id"]
            if vector_id not in map_ids:
                map_ids[vector_id] = cls.decode_vector(hit["_source"]["vector_dump"])
        return map_ids

    @staticmethod
    def _bad_request_reason(error: Any) -> str:
        """Best-effort extraction of the first root-cause reason of an API error.

        Returns "" when the error body does not have the expected shape, so
        that inspecting the error never masks the original exception.
        """
        body = getattr(error, "body", None)
        if not isinstance(body, dict):
            return ""
        details = body.get("error")
        if not isinstance(details, dict):
            return ""
        root_causes = details.get("root_cause")
        if not root_causes or not isinstance(root_causes[0], dict):
            return ""
        reason = root_causes[0].get("reason")
        return reason if isinstance(reason, str) else ""

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Get the values associated with the given keys.

        Returns:
            One entry per input key, in order. Each entry is the stored vector
            bytes, or None when the key is absent.
        """
        # The empty string is a valid key, so test the sequence itself rather
        # than its elements (`any([""])` is False and would drop the key).
        if not keys:
            return []
        cache_keys = [self._key(k) for k in keys]
        if self._is_alias:
            try:
                results = self._es_client.search(
                    index=self._index_name,
                    body={
                        "query": {"ids": {"values": cache_keys}},
                        # Leave room for the same id existing in several
                        # indices behind the alias.
                        "size": len(cache_keys) * self._maximum_duplicates_allowed,
                    },
                    source_includes=["vector_dump", "timestamp"],
                )
            except exceptions.BadRequestError as e:
                if "window too large" in self._bad_request_reason(e):
                    logger.warning(
                        "Exceeded the maximum window size, "
                        "reduce the duplicates manually or lower "
                        "`maximum_duplicates_allowed`."
                    )
                raise
            total_hits = results["hits"]["total"]["value"]
            if self._maximum_duplicates_allowed > 1 and total_hits > len(cache_keys):
                logger.warning(
                    f"Deduplicating, found {total_hits} hits for {len(cache_keys)} keys"
                )
                map_ids = self._deduplicate_hits(results["hits"]["hits"])
            else:
                map_ids = {
                    r["_id"]: self.decode_vector(r["_source"]["vector_dump"])
                    for r in results["hits"]["hits"]
                }
            return [map_ids.get(k) for k in cache_keys]
        else:
            records = self._es_client.mget(
                index=self._index_name, ids=cache_keys, source_includes=["vector_dump"]
            )
            return [
                self.decode_vector(r["_source"]["vector_dump"]) if r["found"] else None
                for r in records["docs"]
            ]

    def build_document(self, text_input: str, vector: bytes) -> Dict[str, Any]:
        """Build the Elasticsearch document for storing a single embedding."""
        body: Dict[str, Any] = {
            "vector_dump": self.encode_vector(vector),
            # Naive local ISO timestamp. `_deduplicate_hits` parses these with
            # `datetime.fromisoformat` and compares them. Mixing in
            # timezone-aware values would make that comparison fail.
            "timestamp": datetime.now().isoformat(),
        }
        if self._metadata is not None:
            body["metadata"] = self._metadata
        if self._store_input:
            body["text_input"] = text_input
        if self._namespace:
            body["namespace"] = self._namespace
        return body

    def _bulk(self, actions: Iterable[Dict[str, Any]]) -> None:
        """Send `actions` with `helpers.bulk`, logging the first failure reason.

        Raises:
            BulkIndexError: re-raised unchanged after logging.
        """
        try:
            helpers.bulk(
                client=self._es_client,
                actions=actions,
                index=self._index_name,
                require_alias=self._is_alias,
                refresh=True,
            )
        except BulkIndexError as e:
            reason = None
            if e.errors:
                # Each error entry is keyed by its operation type: "index"
                # for `mset`, "delete" for `mdelete`. Read whichever is present.
                item = next(iter(e.errors[0].values()), {})
                error = item.get("error", {}) if isinstance(item, dict) else {}
                reason = error.get("reason") if isinstance(error, dict) else error
            logger.error(f"First bulk error reason: {reason}")
            raise

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Set the values for the given keys.

        The key text is also passed to `build_document` as the stored input.
        """
        actions = (
            {
                "_op_type": "index",
                "_id": self._key(key),
                "_source": self.build_document(key, vector),
            }
            for key, vector in key_value_pairs
        )
        self._bulk(actions)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys and their associated values."""
        # NOTE(review): deleting a key that does not exist presumably yields a
        # non-2xx bulk item, which helpers.bulk may report as BulkIndexError.
        # Confirm whether missing keys should be ignored.
        actions = ({"_op_type": "delete", "_id": self._key(key)} for key in keys)
        self._bulk(actions)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Get an iterator over keys that match the given prefix (not supported)."""
        # TODO This method is not currently used by CacheBackedEmbeddings,
        # we can leave it blank. It could be implemented with ES "index_prefixes",
        # but they are limited and expensive.
        raise NotImplementedError()