"""Astra DB - based caches."""from__future__importannotationsimporthashlibimportjsonfromfunctoolsimportlru_cache,wrapsfromtypingimportTYPE_CHECKING,Any,Awaitable,Callable,Generatorfromastrapy.dbimportAstraDB,AsyncAstraDB,loggerfromlangchain_core.cachesimportRETURN_VAL_TYPE,BaseCachefromlangchain_core.language_models.llmsimportaget_prompts,get_promptsfromlangchain_core.load.dumpimportdumpsfromlangchain_core.load.loadimportloadsfromlangchain_core.outputsimportGenerationfromtyping_extensionsimportoverridefromlangchain_astradb.utils.astradbimport(COMPONENT_NAME_CACHE,COMPONENT_NAME_SEMANTICCACHE,SetupMode,_AstraDBCollectionEnvironment,)ifTYPE_CHECKING:fromastrapy.authenticationimportTokenProviderfromlangchain_core.embeddingsimportEmbeddingsfromlangchain_core.language_modelsimportLLMASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME="langchain_astradb_cache"ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME="langchain_astradb_semantic_cache"ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD=0.85ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE=16def_hash(_input:str)->str:"""Use a deterministic hashing approach."""returnhashlib.md5(_input.encode()).hexdigest()# noqa: S324def_dumps_generations(generations:RETURN_VAL_TYPE)->str:"""Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`. Args: generations (RETURN_VAL_TYPE): A list of language model generations. Returns: str: a single string representing a list of generations. This function (+ its counterpart `_loads_generations`) rely on the dumps/loads pair with Reviver, so are able to deal with all subclasses of Generation. Each item in the list can be `dumps`ed to a string, then we make the whole list of strings into a json-dumped. """returnjson.dumps([dumps(_item)for_itemingenerations])def_loads_generations(generations_str:str)->RETURN_VAL_TYPE|None:"""Get Generations from a string. Deserialization of a string into a generic RETURN_VAL_TYPE (i.e. a sequence of `Generation`). See `_dumps_generations`, the inverse of this function. Args: generations_str (str): A string representing a list of generations. Compatible with the legacy cache-blob format Does not raise exceptions for malformed entries, just logs a warning and returns none: the caller should be prepared for such a cache miss. Returns: RETURN_VAL_TYPE: A list of generations. """try:return[loads(_item_str)for_item_strinjson.loads(generations_str)]except(json.JSONDecodeError,TypeError):# deferring the (soft) handling to after the legacy-format attemptpasstry:gen_dicts=json.loads(generations_str)# not relying on `_load_generations_from_json` (which could disappear):except(json.JSONDecodeError,TypeError):logger.warning(f"Malformed/unparsable cached blob encountered: '{generations_str}'")returnNoneelse:generations=[Generation(**generation_dict)forgeneration_dictingen_dicts]logger.warning(f"Legacy 'Generation' cached blob encountered: '{generations_str}'")returngenerations


class AstraDBCache(BaseCache):
    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
        token: str | TokenProvider | None = None,
        api_endpoint: str | None = None,
        namespace: str | None = None,
        environment: str | None = None,
        pre_delete_collection: bool = False,
        setup_mode: SetupMode = SetupMode.SYNC,
        ext_callers: list[tuple[str | None, str | None] | str | None] | None = None,
        astra_db_client: AstraDB | None = None,
        async_astra_db_client: AsyncAstraDB | None = None,
    ):
        """Cache that uses Astra DB as a backend.

        It uses a single collection as a kv store.
        The lookup keys, combined in the _id of the documents, are:

        - prompt, a string
        - llm_string, a deterministic str representation of the model parameters
          (needed to prevent same-prompt-different-model collisions)

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage, either in the form of a string
                or a subclass of `astrapy.authentication.TokenProvider`.
                If not provided, the environment variable
                ASTRA_DB_APPLICATION_TOKEN is inspected.
            api_endpoint: full URL to the API endpoint, such as
                `https://<DB-ID>-us-east1.apps.astra.datastax.com`. If not provided,
                the environment variable ASTRA_DB_API_ENDPOINT is inspected.
            namespace: namespace (aka keyspace) where the collection is created.
                If not provided, the environment variable ASTRA_DB_KEYSPACE is
                inspected. Defaults to the database's "default namespace".
            environment: a string specifying the environment of the target Data API.
                If omitted, defaults to "prod" (Astra DB production).
                Other values are in `astrapy.constants.Environment` enum class.
            setup_mode: mode used to create the Astra DB collection
                (SYNC, ASYNC or OFF).
            pre_delete_collection: whether to delete the collection before creating
                it. If False and the collection already exists, the collection
                will be used as is.
            ext_callers: one or more caller identities to identify Data API calls
                in the User-Agent header. This is a list of (name, version) pairs,
                or just strings if no version info is provided, which, if supplied,
                becomes the leading part of the User-Agent string in all API
                requests related to this component.
            astra_db_client: *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally 'environment'.*
                you can pass an already-created 'astrapy.db.AstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
            async_astra_db_client: *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally 'environment'.*
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
        """
        self.astra_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            keyspace=namespace,
            environment=environment,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
            ext_callers=ext_callers,
            component_name=COMPONENT_NAME_CACHE,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
        )
        self.collection = self.astra_env.collection
        self.async_collection = self.astra_env.async_collection

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: list[str] | None = None
    ) -> None:
        """A wrapper around `delete` with the LLM being passed.

        In case the llm(prompt) calls have a `stop` param, you should
        pass it here.
        """
        llm_string = get_prompts(
            {**llm.dict(), "stop": stop},
            [],
        )[1]
        return self.delete(prompt, llm_string=llm_string)

    async def adelete_through_llm(
        self, prompt: str, llm: LLM, stop: list[str] | None = None
    ) -> None:
        """A wrapper around `adelete` with the LLM being passed.

        In case the llm(prompt) calls have a `stop` param, you should
        pass it here.
        """
        llm_string = (
            await aget_prompts(
                {**llm.dict(), "stop": stop},
                [],
            )
        )[1]
        return await self.adelete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        self.astra_env.ensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        self.collection.delete_one({"_id": doc_id})

    async def adelete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        await self.astra_env.aensure_db_setup()
        doc_id = self._make_id(prompt, llm_string)
        await self.async_collection.delete_one({"_id": doc_id})
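

# Usage sketch (assumptions: the endpoint/token values are placeholders,
# `my_llm` stands for any `LLM` instance, and the function itself is a
# hypothetical helper, not part of this module's API). It wires the exact-match
# cache into LangChain's global LLM cache, then evicts one cached prompt.
def _example_exact_cache_usage(my_llm: LLM) -> None:  # pragma: no cover - sketch
    from langchain_core.globals import set_llm_cache

    cache = AstraDBCache(
        api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
        token="AstraCS:...",
    )
    set_llm_cache(cache)  # identical (prompt, llm_string) calls now hit the cache
    # later, to invalidate a single cached prompt for this LLM:
    cache.delete_through_llm("How far is the Moon?", my_llm)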
_unset=["unset"]class_CachedAwaitable:"""Cache the result of an awaitable so it can be awaited multiple times."""def__init__(self,awaitable:Awaitable[Any]):self.awaitable=awaitableself.result=_unsetdef__await__(self)->Generator:ifself.resultis_unset:self.result=yield fromself.awaitable.__await__()returnself.resultdef_reawaitable(func:Callable)->Callable:"""Make an async function result awaitable multiple times."""@wraps(func)defwrapper(*args:Any,**kwargs:Any)->_CachedAwaitable:return_CachedAwaitable(func(*args,**kwargs))returnwrapperdef_async_lru_cache(maxsize:int=128)->Callable:"""Least-recently-used async cache decorator. Equivalent to functools.lru_cache for async functions. """defdecorating_function(user_function:Callable)->Callable:returnlru_cache(maxsize)(_reawaitable(user_function))returndecorating_function


class AstraDBSemanticCache(BaseCache):
    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME,
        token: str | TokenProvider | None = None,
        api_endpoint: str | None = None,
        namespace: str | None = None,
        environment: str | None = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding: Embeddings,
        metric: str | None = None,
        similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
        ext_callers: list[tuple[str | None, str | None] | str | None] | None = None,
        astra_db_client: AstraDB | None = None,
        async_astra_db_client: AsyncAstraDB | None = None,
    ):
        """Astra DB semantic cache.

        Cache that uses Astra DB as a vector-store backend for semantic
        (i.e. similarity-based) lookup.

        It uses a single (vector) collection and can store cached values
        from several LLMs, so the LLM's 'llm_string' is stored in the
        document metadata.

        You can choose the preferred similarity (or use the API default).
        The default score threshold is tuned to the default metric.
        Tune it carefully yourself if switching to another distance metric.

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage, either in the form of a string
                or a subclass of `astrapy.authentication.TokenProvider`.
                If not provided, the environment variable
                ASTRA_DB_APPLICATION_TOKEN is inspected.
            api_endpoint: full URL to the API endpoint, such as
                `https://<DB-ID>-us-east1.apps.astra.datastax.com`. If not provided,
                the environment variable ASTRA_DB_API_ENDPOINT is inspected.
            namespace: namespace (aka keyspace) where the collection is created.
                If not provided, the environment variable ASTRA_DB_KEYSPACE is
                inspected. Defaults to the database's "default namespace".
            environment: a string specifying the environment of the target Data API.
                If omitted, defaults to "prod" (Astra DB production).
                Other values are in `astrapy.constants.Environment` enum class.
            setup_mode: mode used to create the Astra DB collection
                (SYNC, ASYNC or OFF).
            pre_delete_collection: whether to delete the collection before creating
                it. If False and the collection already exists, the collection
                will be used as is.
            embedding: Embedding provider for semantic encoding and search.
            metric: the function to use for evaluating similarity of text
                embeddings. Defaults to 'cosine' (alternatives: 'euclidean',
                'dot_product').
            similarity_threshold: the minimum similarity for accepting a
                (semantic-search) match.
            ext_callers: one or more caller identities to identify Data API calls
                in the User-Agent header. This is a list of (name, version) pairs,
                or just strings if no version info is provided, which, if supplied,
                becomes the leading part of the User-Agent string in all API
                requests related to this component.
            astra_db_client: *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally 'environment'.*
                you can pass an already-created 'astrapy.db.AstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
            async_astra_db_client: *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally 'environment'.*
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
"""self.embedding=embeddingself.metric=metricself.similarity_threshold=similarity_thresholdself.collection_name=collection_name# The contract for this class has separate lookup and update:# in order to spare some embedding calculations we cache them between# the two calls.# Note: each instance of this class has its own `_get_embedding` with# its own lru.@lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)def_cache_embedding(text:str)->list[float]:returnself.embedding.embed_query(text=text)self._get_embedding=_cache_embedding@_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)asyncdef_acache_embedding(text:str)->list[float]:returnawaitself.embedding.aembed_query(text=text)self._aget_embedding=_acache_embeddingembedding_dimension:int|Awaitable[int]|None=Noneifsetup_mode==SetupMode.ASYNC:embedding_dimension=self._aget_embedding_dimension()elifsetup_mode==SetupMode.SYNC:embedding_dimension=self._get_embedding_dimension()self.astra_env=_AstraDBCollectionEnvironment(collection_name=collection_name,token=token,api_endpoint=api_endpoint,keyspace=namespace,environment=environment,setup_mode=setup_mode,pre_delete_collection=pre_delete_collection,embedding_dimension=embedding_dimension,metric=metric,ext_callers=ext_callers,component_name=COMPONENT_NAME_SEMANTICCACHE,astra_db_client=astra_db_client,async_astra_db_client=async_astra_db_client,)self.collection=self.astra_env.collectionself.async_collection=self.astra_env.async_collection

    def _get_embedding_dimension(self) -> int:
        return len(self._get_embedding(text="This is a sample sentence."))

    async def _aget_embedding_dimension(self) -> int:
        return len(await self._aget_embedding(text="This is a sample sentence."))

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        return f"{_hash(prompt)}#{_hash(llm_string)}"

    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> tuple[str, RETURN_VAL_TYPE] | None:
        """Look up based on prompt and llm_string.

        If there are hits, return (document_id, cached_entry) for the top hit.
        """
        self.astra_env.ensure_db_setup()
        prompt_embedding: list[float] = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)

        hit = self.collection.find_one(
            filter={
                "llm_string_hash": llm_string_hash,
            },
            sort={"$vector": prompt_embedding},
            projection={"body_blob": True, "_id": True},
            include_similarity=True,
        )

        if hit is None or hit["$similarity"] < self.similarity_threshold:
            return None

        generations = _loads_generations(hit["body_blob"])
        if generations is None:
            # this protects against malformed cached items:
            return None
        return hit["_id"], generations

    async def alookup_with_id(
        self, prompt: str, llm_string: str
    ) -> tuple[str, RETURN_VAL_TYPE] | None:
        """Look up based on prompt and llm_string.

        If there are hits, return (document_id, cached_entry) for the top hit.
        """
        await self.astra_env.aensure_db_setup()
        prompt_embedding: list[float] = await self._aget_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)

        hit = await self.async_collection.find_one(
            filter={
                "llm_string_hash": llm_string_hash,
            },
            sort={"$vector": prompt_embedding},
            projection={"body_blob": True, "_id": True},
            include_similarity=True,
        )

        if hit is None or hit["$similarity"] < self.similarity_threshold:
            return None

        generations = _loads_generations(hit["body_blob"])
        if generations is None:
            # this protects against malformed cached items:
            return None
        return hit["_id"], generations

    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: list[str] | None = None
    ) -> tuple[str, RETURN_VAL_TYPE] | None:
        """Look up based on prompt and LLM.

        If there are hits, return (document_id, cached_entry) for the top hit.
        """
        llm_string = get_prompts(
            {**llm.dict(), "stop": stop},
            [],
        )[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    async def alookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: list[str] | None = None
    ) -> tuple[str, RETURN_VAL_TYPE] | None:
        """Look up based on prompt and LLM.

        If there are hits, return (document_id, cached_entry) for the top hit.
        """
        llm_string = (
            await aget_prompts(
                {**llm.dict(), "stop": stop},
                [],
            )
        )[1]
        return await self.alookup_with_id(prompt, llm_string=llm_string)

    def delete_by_document_id(self, document_id: str) -> None:
        """Delete by document ID.

        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with that ID. This is for the second step.
        """
        self.astra_env.ensure_db_setup()
        self.collection.delete_one({"_id": document_id})

    async def adelete_by_document_id(self, document_id: str) -> None:
        """Delete by document ID.

        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with that ID. This is for the second step.
        """
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_one({"_id": document_id})
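

# Usage sketch (assumptions: `my_embeddings` and `my_llm` stand for any
# `Embeddings`/`LLM` instances, endpoint and token values are placeholders,
# and the function itself is hypothetical, not part of this module's API).
# It shows the two-step invalidation pattern documented above: look up the top
# semantic hit to obtain its document ID, then delete by that ID.
def _example_semantic_invalidation(
    my_embeddings: Embeddings, my_llm: LLM
) -> None:  # pragma: no cover - documentation sketch
    sem_cache = AstraDBSemanticCache(
        embedding=my_embeddings,
        api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
        token="AstraCS:...",
    )
    lookup = sem_cache.lookup_with_id_through_llm("How far is the Moon?", my_llm)
    if lookup is not None:
        doc_id, _cached_generations = lookup
        sem_cache.delete_by_document_id(doc_id)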