Source code for langchain_aws.vectorstores.inmemorydb.cache
from __future__ import annotations

import hashlib
import json
import logging
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Union,
    cast,
)

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation

from langchain_aws.vectorstores.inmemorydb import InMemoryVectorStore

logger = logging.getLogger(__file__)


def _hash(_input: str) -> str:
    """Use a deterministic hashing approach."""
    return hashlib.md5(_input.encode()).hexdigest()


def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str:
    """Dump generations to json.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: Json representing a list of generations.

    Warning: would not work well with arbitrary subclasses of `Generation`
    """
    return json.dumps([generation.dict() for generation in generations])


def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
    """Load generations from json.

    Args:
        generations_json (str): A string of json representing a list of generations.

    Raises:
        ValueError: Could not decode json string to list of generations.

    Returns:
        RETURN_VAL_TYPE: A list of generations.

    Warning: would not work well with arbitrary subclasses of `Generation`
    """
    try:
        results = json.loads(generations_json)
        return [Generation(**generation_dict) for generation_dict in results]
    except json.JSONDecodeError:
        raise ValueError(
            f"Could not decode json to list of generations: {generations_json}"
        )


def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """
    Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`.

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.

    This function (+ its counterpart `_loads_generations`) relies on
    the dumps/loads pair with Reviver, and so can handle all subclasses
    of `Generation`.

    Each item in the list is `dumps`-ed to a string,
    then the whole list of strings is json-dumped.
    """
    return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """
    Deserialization of a string into a generic RETURN_VAL_TYPE
    (i.e. a sequence of `Generation`).

    See `_dumps_generations`, the inverse of this function.

    Args:
        generations_str (str): A string representing a list of generations.

    Compatible with the legacy cache-blob format.
    Does not raise exceptions for malformed entries, just logs a warning
    and returns None: the caller should be prepared for such a cache miss.

    Returns:
        RETURN_VAL_TYPE: A list of generations.
    """
    try:
        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
        return generations
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    try:
        gen_dicts = json.loads(generations_str)
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
        logger.warning(
            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
        )
        return generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None
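As a rough sketch of how these helpers fit together (the example text and values below are hypothetical, using only the names defined above), the two serialization formats round-trip like this:

    from langchain_core.outputs import Generation

    gens = [Generation(text="Paris is the capital of France.")]

    # Current format: each Generation is `dumps`-ed, then the list of strings is json-dumped.
    blob = _dumps_generations(gens)
    assert _loads_generations(blob) == gens

    # Legacy format (plain Generation dicts): `_loads_generations` falls back to it
    # and logs a "Legacy 'Generation' cached blob encountered" warning.
    legacy_blob = _dump_generations_to_json(gens)
    assert _loads_generations(legacy_blob) == gens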
[docs]
class InMemorySemanticCache(BaseCache):
    """Cache that uses MemoryDB as a vector-store backend."""

    # TODO - implement a TTL policy in MemoryDB

    DEFAULT_SCHEMA = {
        "content_key": "prompt",
        "text": [
            {"name": "prompt"},
            {"name": "return_val"},
            {"name": "llm_string"},
        ],
    }
[docs]
    def __init__(
        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """Initialize the semantic cache backed by a MemoryDB vector store.

        Args:
            redis_url (str): URL to connect to MemoryDB.
            embedding (Embeddings): Embedding provider for
                semantic encoding and search.
            score_threshold (float, optional): Score threshold for semantic
                similarity matches. Defaults to 0.2.

        Example:

        .. code-block:: python

            from langchain_core.globals import set_llm_cache
            from langchain_aws.cache import InMemorySemanticCache

            set_llm_cache(InMemorySemanticCache(
                redis_url="redis://localhost:6379",
                embedding=OpenAIEmbeddings()
            ))
        """
        self._cache_dict: Dict[str, InMemoryVectorStore] = {}
        self.redis_url = redis_url
        self.embedding = embedding
        self.score_threshold = score_threshold
    def _index_name(self, llm_string: str) -> str:
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> InMemoryVectorStore:
        index_name = self._index_name(llm_string)

        # return vectorstore client for the specific llm string
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # create new vectorstore client for the specific llm string
        try:
            self._cache_dict[index_name] = InMemoryVectorStore.from_existing_index(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
        except ValueError:
            inmemory = InMemoryVectorStore(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
            _embedding = self.embedding.embed_query(text="test")
            inmemory._create_index_if_not_exist(dim=len(_embedding))
            self._cache_dict[index_name] = inmemory

        return self._cache_dict[index_name]
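    # How the lazy index bookkeeping above plays out (values hypothetical):
    #
    #     _index_name("model=example-llm")   # -> "cache:" + 32-char md5 hex digest
    #
    # The same llm_string therefore always maps to the same MemoryDB index; an
    # unknown llm_string hits the ValueError branch and creates its index lazily,
    # with the dimension taken from an embedding of the literal string "test".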
[docs]
    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string."""
        index_name = self._index_name(kwargs["llm_string"])
        if index_name in self._cache_dict:
            self._cache_dict[index_name].drop_index(
                index_name=index_name, delete_documents=True, redis_url=self.redis_url
            )
            del self._cache_dict[index_name]
[docs]
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Read from a Hash
        results = llm_cache.similarity_search(
            query=prompt,
            distance_threshold=0.1,
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None
[docs]
    def update(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        """Update cache based on prompt and llm_string."""
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "InMemorySemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
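A minimal end-to-end sketch of the intended call pattern (not part of the module): `my_embeddings` stands in for any `Embeddings` implementation, the URL and `llm_string` are placeholders, and a reachable MemoryDB/Redis-compatible endpoint is assumed.

    from langchain_core.outputs import Generation

    cache = InMemorySemanticCache(
        redis_url="redis://localhost:6379",  # placeholder endpoint
        embedding=my_embeddings,             # placeholder Embeddings implementation
    )

    llm_string = "model=example-llm;temperature=0"  # placeholder cache namespace key

    # Store a generation for a prompt ...
    cache.update("What is 2 + 2?", llm_string, [Generation(text="4")])

    # ... then a semantically similar prompt can hit the cache.
    cached = cache.lookup("what's 2+2?", llm_string)
    if cached is not None:
        print(cached[0].text)  # "4"

    # Drop the per-llm_string index when finished.
    cache.clear(llm_string=llm_string)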