# Source code for langchain.memory.vectorstore

"""Class for a VectorStore-backed memory object."""

from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.vectorstores import VectorStoreRetriever

from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key


class VectorStoreRetrieverMemory(BaseMemory):
    """Memory whose contents live behind a ``VectorStoreRetriever``.

    Each saved conversation turn is written to the retriever as a single
    ``Document``; loading memory queries the retriever with the value of the
    prompt input key and returns whatever documents come back.
    """

    retriever: VectorStoreRetriever = Field(exclude=True)
    """Retriever used both to look up memories and to store new ones."""

    memory_key: str = "history"  #: :meta private:
    """Key under which memories appear in the ``load_memory_variables`` result."""

    input_key: Optional[str] = None
    """Key of ``inputs`` holding the query; resolved automatically when ``None``."""

    return_docs: bool = False
    """Return the retrieved ``Document`` objects instead of their joined text."""

    exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
    """Input keys, besides ``memory_key``, left out of the stored document."""

    @property
    def memory_variables(self) -> List[str]:
        """Keys this memory contributes: a one-element list with ``memory_key``."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Pick the key of ``inputs`` whose value is used as the query.

        Args:
            inputs: The inputs passed to the chain for the current turn.

        Returns:
            ``self.input_key`` when it is set; otherwise the key chosen by
            ``get_prompt_input_key`` given ``inputs`` and ``memory_variables``.
        """
        if self.input_key is not None:
            return self.input_key
        return get_prompt_input_key(inputs, self.memory_variables)

    def _documents_to_memory_variables(
        self, docs: List[Document]
    ) -> Dict[str, Union[List[Document], str]]:
        """Wrap retrieved documents in the dict shape memory consumers expect.

        Args:
            docs: Documents returned by the retriever.

        Returns:
            A single-entry dict keyed by ``memory_key``. The value is ``docs``
            itself when ``return_docs`` is true, otherwise the documents'
            ``page_content`` joined with newlines.
        """
        payload: Union[List[Document], str]
        if self.return_docs:
            payload = docs
        else:
            payload = "\n".join(doc.page_content for doc in docs)
        return {self.memory_key: payload}

    def load_memory_variables(
        self, inputs: Dict[str, Any]
    ) -> Dict[str, Union[List[Document], str]]:
        """Query the retriever with the prompt input and return the matches.

        Args:
            inputs: The inputs for the current turn; the value under the prompt
                input key is passed to ``retriever.invoke`` as the query.

        Returns:
            A single-entry dict keyed by ``memory_key`` (see
            ``_documents_to_memory_variables`` for the value's form).
        """
        query = inputs[self._get_prompt_input_key(inputs)]
        retrieved = self.retriever.invoke(query)
        return self._documents_to_memory_variables(retrieved)

    async def aload_memory_variables(
        self, inputs: Dict[str, Any]
    ) -> Dict[str, Union[List[Document], str]]:
        """Async counterpart of ``load_memory_variables``.

        Args:
            inputs: The inputs for the current turn; the value under the prompt
                input key is passed to ``retriever.ainvoke`` as the query.

        Returns:
            A single-entry dict keyed by ``memory_key`` (see
            ``_documents_to_memory_variables`` for the value's form).
        """
        query = inputs[self._get_prompt_input_key(inputs)]
        retrieved = await self.retriever.ainvoke(query)
        return self._documents_to_memory_variables(retrieved)

    def _form_documents(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> List[Document]:
        """Render one conversation turn as a single-document list.

        Args:
            inputs: Inputs of the turn; entries keyed by ``memory_key`` or by
                any of ``exclude_input_keys`` are skipped.
            outputs: Outputs of the turn; all entries are included.

        Returns:
            A list holding one ``Document`` whose ``page_content`` has one
            ``"key: value"`` line per kept input followed by one per output.
        """
        # Only the current turn belongs in the document, so the previously
        # loaded history (stored under memory_key) is dropped with the other
        # excluded keys.
        skipped = {*self.exclude_input_keys, self.memory_key}
        lines = [
            f"{key}: {value}" for key, value in inputs.items() if key not in skipped
        ]
        lines.extend(f"{key}: {value}" for key, value in outputs.items())
        return [Document(page_content="\n".join(lines))]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Store this turn by adding its document to the retriever.

        Args:
            inputs: Inputs of the turn.
            outputs: Outputs of the turn.
        """
        self.retriever.add_documents(self._form_documents(inputs, outputs))

    async def asave_context(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> None:
        """Async counterpart of ``save_context``.

        Args:
            inputs: Inputs of the turn.
            outputs: Outputs of the turn.
        """
        await self.retriever.aadd_documents(self._form_documents(inputs, outputs))

    def clear(self) -> None:
        """Do nothing; this method has no body beyond its docstring."""

    async def aclear(self) -> None:
        """Do nothing; async counterpart of ``clear``."""