Source code for langchain_community.graph_vectorstores.cassandra

"""Apache Cassandra DB graph vector store integration."""

from __future__ import annotations

import asyncio
import json
import logging
import secrets
from dataclasses import asdict, is_dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterable,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    cast,
)

from langchain_core._api import beta
from langchain_core.documents import Document
from typing_extensions import override

from langchain_community.graph_vectorstores.base import GraphVectorStore, Node
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.mmr_helper import MmrHelper
from langchain_community.utilities.cassandra import SetupMode
from langchain_community.vectorstores.cassandra import Cassandra as CassandraVectorStore

CGVST = TypeVar("CGVST", bound="CassandraGraphVectorStore")

if TYPE_CHECKING:
    from cassandra.cluster import Session
    from langchain_core.embeddings import Embeddings


logger = logging.getLogger(__name__)


[docs] class AdjacentNode: id: str links: list[Link] embedding: list[float]
[docs] def __init__(self, node: Node, embedding: list[float]) -> None: """Create an Adjacent Node.""" self.id = node.id or "" self.links = node.links self.embedding = embedding
def _serialize_links(links: list[Link]) -> str: class SetAndLinkEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: # noqa: ANN401 if not isinstance(obj, type) and is_dataclass(obj): return asdict(obj) if isinstance(obj, Iterable): return list(obj) # Let the base class default method raise the TypeError return super().default(obj) return json.dumps(links, cls=SetAndLinkEncoder) def _deserialize_links(json_blob: str | None) -> set[Link]: return { Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) for link in cast(list[dict[str, Any]], json.loads(json_blob or "[]")) } def _metadata_link_key(link: Link) -> str: return f"link:{link.kind}:{link.tag}" def _metadata_link_value() -> str: return "link" def _doc_to_node(doc: Document) -> Node: metadata = doc.metadata.copy() links = _deserialize_links(metadata.get(METADATA_LINKS_KEY)) metadata[METADATA_LINKS_KEY] = links return Node( id=doc.id, text=doc.page_content, metadata=metadata, links=list(links), ) def _incoming_links(node: Node | AdjacentNode) -> set[Link]: return {link for link in node.links if link.direction in ["in", "bidir"]} def _outgoing_links(node: Node | AdjacentNode) -> set[Link]: return {link for link in node.links if link.direction in ["out", "bidir"]}
[docs] @beta() class CassandraGraphVectorStore(GraphVectorStore):
[docs] def __init__( self, embedding: Embeddings, session: Session | None = None, keyspace: str | None = None, table_name: str = "", ttl_seconds: int | None = None, *, body_index_options: list[tuple[str, Any]] | None = None, setup_mode: SetupMode = SetupMode.SYNC, metadata_deny_list: Optional[list[str]] = None, ) -> None: """Apache Cassandra(R) for graph-vector-store workloads. To use it, you need a recent installation of the `cassio` library and a Cassandra cluster / Astra DB instance supporting vector capabilities. Example: .. code-block:: python from langchain_community.graph_vectorstores import CassandraGraphVectorStore from langchain_openai import OpenAIEmbeddings embeddings = OpenAIEmbeddings() session = ... # create your Cassandra session object keyspace = 'my_keyspace' # the keyspace should exist already table_name = 'my_graph_vector_store' vectorstore = CassandraGraphVectorStore( embeddings, session, keyspace, table_name, ) Args: embedding: Embedding function to use. session: Cassandra driver session. If not provided, it is resolved from cassio. keyspace: Cassandra keyspace. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] setup_mode: mode used to create the Cassandra table (SYNC, ASYNC or OFF). metadata_deny_list: Optional list of metadata keys to not index. i.e. to fine-tune which of the metadata fields are indexed. Note: if you plan to have massive unique text metadata entries, consider not indexing them for performance (and to overcome max-length limitations). Note: the `metadata_indexing` parameter from langchain_community.utilities.cassandra.Cassandra is not exposed since CassandraGraphVectorStore only supports the deny_list option. """ self.embedding = embedding if metadata_deny_list is None: metadata_deny_list = [] metadata_deny_list.append(METADATA_LINKS_KEY) self.vector_store = CassandraVectorStore( embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, ttl_seconds=ttl_seconds, body_index_options=body_index_options, setup_mode=setup_mode, metadata_indexing=("deny_list", metadata_deny_list), ) store_session: Session = self.vector_store.session self._insert_node = store_session.prepare( f""" INSERT INTO {keyspace}.{table_name} ( row_id, body_blob, vector, attributes_blob, metadata_s ) VALUES (?, ?, ?, ?, ?) """ # noqa: S608 )
@property @override def embeddings(self) -> Embeddings | None: return self.embedding def _get_metadata_filter( self, metadata: dict[str, Any] | None = None, outgoing_link: Link | None = None, ) -> dict[str, Any]: if outgoing_link is None: return metadata or {} metadata_filter = {} if metadata is None else metadata.copy() metadata_filter[_metadata_link_key(link=outgoing_link)] = _metadata_link_value() return metadata_filter def _restore_links(self, doc: Document) -> Document: """Restores the links in the document by deserializing them from metadata. Args: doc: A single Document Returns: The same Document with restored links. """ links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) doc.metadata[METADATA_LINKS_KEY] = links # TODO: Could this be skipped if we put these metadata entries # only in the searchable `metadata_s` column? for incoming_link_key in [ _metadata_link_key(link=link) for link in links if link.direction in ["in", "bidir"] ]: if incoming_link_key in doc.metadata: del doc.metadata[incoming_link_key] return doc def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]: metadata = node.metadata.copy() metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) # TODO: Could we could put these metadata entries # only in the searchable `metadata_s` column? for incoming_link in _incoming_links(node=node): metadata[_metadata_link_key(link=incoming_link)] = _metadata_link_value() return metadata def _get_docs_for_insertion( self, nodes: Iterable[Node] ) -> tuple[list[Document], list[str]]: docs = [] ids = [] for node in nodes: node_id = secrets.token_hex(8) if not node.id else node.id doc = Document( page_content=node.text, metadata=self._get_node_metadata_for_insertion(node=node), id=node_id, ) docs.append(doc) ids.append(node_id) return (docs, ids)
[docs] @override def add_nodes( self, nodes: Iterable[Node], **kwargs: Any, ) -> Iterable[str]: """Add nodes to the graph store. Args: nodes: the nodes to add. **kwargs: Additional keyword arguments. """ (docs, ids) = self._get_docs_for_insertion(nodes=nodes) return self.vector_store.add_documents(docs, ids=ids)
[docs] @override async def aadd_nodes( self, nodes: Iterable[Node], **kwargs: Any, ) -> AsyncIterable[str]: """Add nodes to the graph store. Args: nodes: the nodes to add. **kwargs: Additional keyword arguments. """ (docs, ids) = self._get_docs_for_insertion(nodes=nodes) for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids): yield inserted_id
[docs] @override def similarity_search_by_vector( self, embedding: list[float], k: int = 4, filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to embedding vector. Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filter on the metadata to apply. **kwargs: Additional arguments are ignored. Returns: The list of Documents most similar to the query vector. """ return [ self._restore_links(doc) for doc in self.vector_store.similarity_search_by_vector( embedding, k=k, filter=filter, **kwargs, ) ]
[docs] @override async def asimilarity_search_by_vector( self, embedding: list[float], k: int = 4, filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to embedding vector. Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filter on the metadata to apply. **kwargs: Additional arguments are ignored. Returns: The list of Documents most similar to the query vector. """ return [ self._restore_links(doc) for doc in await self.vector_store.asimilarity_search_by_vector( embedding, k=k, filter=filter, **kwargs, ) ]
[docs] def get_by_document_id(self, document_id: str) -> Document | None: """Retrieve a single document from the store, given its document ID. Args: document_id: The document ID Returns: The the document if it exists. Otherwise None. """ doc = self.vector_store.get_by_document_id(document_id=document_id) return self._restore_links(doc) if doc is not None else None
[docs] async def aget_by_document_id(self, document_id: str) -> Document | None: """Retrieve a single document from the store, given its document ID. Args: document_id: The document ID Returns: The the document if it exists. Otherwise None. """ doc = await self.vector_store.aget_by_document_id(document_id=document_id) return self._restore_links(doc) if doc is not None else None
[docs] def get_node(self, node_id: str) -> Node | None: """Retrieve a single node from the store, given its ID. Args: node_id: The node ID Returns: The the node if it exists. Otherwise None. """ doc = self.vector_store.get_by_document_id(document_id=node_id) if doc is None: return None return _doc_to_node(doc=doc)
async def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]: """Return the set of outgoing links for the given source IDs asynchronously. Args: source_ids: The IDs of the source nodes to retrieve outgoing links for. Returns: A set of `Link` objects representing the outgoing links from the source nodes. """ links = set() # Create coroutine objects without scheduling them yet coroutines = [ self.vector_store.aget_by_document_id(document_id=source_id) for source_id in source_ids ] # Schedule and await all coroutines docs = await asyncio.gather(*coroutines) for doc in docs: if doc is not None: node = _doc_to_node(doc=doc) links.update(_outgoing_links(node=node)) return links async def _get_adjacent( self, links: set[Link], query_embedding: list[float], retrieved_docs: dict[str, Document], k_per_link: int | None = None, filter: dict[str, Any] | None = None, # noqa: A002 ) -> Iterable[AdjacentNode]: """Return the target nodes with incoming links from any of the given links. Args: links: The links to look for. query_embedding: The query embedding. Used to rank target nodes. retrieved_docs: A cache of retrieved docs. This will be added to. k_per_link: The number of target nodes to fetch for each link. filter: Optional metadata to filter the results. Returns: Iterable of adjacent edges. """ targets: dict[str, AdjacentNode] = {} tasks = [] for link in links: metadata_filter = self._get_metadata_filter( metadata=filter, outgoing_link=link, ) tasks.append( self.vector_store.asimilarity_search_with_embedding_id_by_vector( embedding=query_embedding, k=k_per_link or 10, filter=metadata_filter, ) ) results = await asyncio.gather(*tasks) for result in results: for doc, embedding, doc_id in result: if doc_id not in retrieved_docs: retrieved_docs[doc_id] = doc if doc_id not in targets: node = _doc_to_node(doc=doc) targets[doc_id] = AdjacentNode(node=node, embedding=embedding) # TODO: Consider a combined limit based on the similarity and/or # predicated MMR score? return targets.values() @staticmethod def _build_docs_from_texts( texts: List[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, ) -> List[Document]: docs: List[Document] = [] for i, text in enumerate(texts): doc = Document( page_content=text, ) if metadatas is not None: doc.metadata = metadatas[i] if ids is not None: doc.id = ids[i] docs.append(doc) return docs
[docs] @classmethod def from_texts( cls: Type[CGVST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, *, session: Optional[Session] = None, keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_deny_list: Optional[list[str]] = None, **kwargs: Any, ) -> CGVST: """Create a CassandraGraphVectorStore from raw texts. Args: texts: Texts to add to the vectorstore. embedding: Embedding function to use. metadatas: Optional list of metadatas associated with the texts. session: Cassandra driver session. If not provided, it is resolved from cassio. keyspace: Cassandra key space. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the texts. ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] metadata_deny_list: Optional list of metadata keys to not index. i.e. to fine-tune which of the metadata fields are indexed. Note: if you plan to have massive unique text metadata entries, consider not indexing them for performance (and to overcome max-length limitations). Note: the `metadata_indexing` parameter from langchain_community.utilities.cassandra.Cassandra is not exposed since CassandraGraphVectorStore only supports the deny_list option. Returns: a CassandraGraphVectorStore. """ docs = cls._build_docs_from_texts( texts=texts, metadatas=metadatas, ids=ids, ) return cls.from_documents( documents=docs, embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, ttl_seconds=ttl_seconds, body_index_options=body_index_options, metadata_deny_list=metadata_deny_list, **kwargs, )
[docs] @classmethod async def afrom_texts( cls: Type[CGVST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, *, session: Optional[Session] = None, keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_deny_list: Optional[list[str]] = None, **kwargs: Any, ) -> CGVST: """Create a CassandraGraphVectorStore from raw texts. Args: texts: Texts to add to the vectorstore. embedding: Embedding function to use. metadatas: Optional list of metadatas associated with the texts. session: Cassandra driver session. If not provided, it is resolved from cassio. keyspace: Cassandra key space. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the texts. ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] metadata_deny_list: Optional list of metadata keys to not index. i.e. to fine-tune which of the metadata fields are indexed. Note: if you plan to have massive unique text metadata entries, consider not indexing them for performance (and to overcome max-length limitations). Note: the `metadata_indexing` parameter from langchain_community.utilities.cassandra.Cassandra is not exposed since CassandraGraphVectorStore only supports the deny_list option. Returns: a CassandraGraphVectorStore. """ docs = cls._build_docs_from_texts( texts=texts, metadatas=metadatas, ids=ids, ) return await cls.afrom_documents( documents=docs, embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, ttl_seconds=ttl_seconds, body_index_options=body_index_options, metadata_deny_list=metadata_deny_list, **kwargs, )
@staticmethod def _add_ids_to_docs( docs: List[Document], ids: Optional[List[str]] = None, ) -> List[Document]: if ids is not None: for doc, doc_id in zip(docs, ids): doc.id = doc_id return docs
[docs] @classmethod def from_documents( cls: Type[CGVST], documents: List[Document], embedding: Embeddings, *, session: Optional[Session] = None, keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_deny_list: Optional[list[str]] = None, **kwargs: Any, ) -> CGVST: """Create a CassandraGraphVectorStore from a document list. Args: documents: Documents to add to the vectorstore. embedding: Embedding function to use. session: Cassandra driver session. If not provided, it is resolved from cassio. keyspace: Cassandra key space. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the documents. ttl_seconds: Optional time-to-live for the added documents. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] metadata_deny_list: Optional list of metadata keys to not index. i.e. to fine-tune which of the metadata fields are indexed. Note: if you plan to have massive unique text metadata entries, consider not indexing them for performance (and to overcome max-length limitations). Note: the `metadata_indexing` parameter from langchain_community.utilities.cassandra.Cassandra is not exposed since CassandraGraphVectorStore only supports the deny_list option. Returns: a CassandraGraphVectorStore. """ store = cls( embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, ttl_seconds=ttl_seconds, body_index_options=body_index_options, metadata_deny_list=metadata_deny_list, **kwargs, ) store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids)) return store
[docs] @classmethod async def afrom_documents( cls: Type[CGVST], documents: List[Document], embedding: Embeddings, *, session: Optional[Session] = None, keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_deny_list: Optional[list[str]] = None, **kwargs: Any, ) -> CGVST: """Create a CassandraGraphVectorStore from a document list. Args: documents: Documents to add to the vectorstore. embedding: Embedding function to use. session: Cassandra driver session. If not provided, it is resolved from cassio. keyspace: Cassandra key space. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the documents. ttl_seconds: Optional time-to-live for the added documents. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] metadata_deny_list: Optional list of metadata keys to not index. i.e. to fine-tune which of the metadata fields are indexed. Note: if you plan to have massive unique text metadata entries, consider not indexing them for performance (and to overcome max-length limitations). Note: the `metadata_indexing` parameter from langchain_community.utilities.cassandra.Cassandra is not exposed since CassandraGraphVectorStore only supports the deny_list option. Returns: a CassandraGraphVectorStore. """ store = cls( embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, ttl_seconds=ttl_seconds, setup_mode=SetupMode.ASYNC, body_index_options=body_index_options, metadata_deny_list=metadata_deny_list, **kwargs, ) await store.aadd_documents( documents=cls._add_ids_to_docs(docs=documents, ids=ids) ) return store