"""Apache Cassandra DB graph vector store integration."""
from __future__ import annotations
import asyncio
import json
import logging
import secrets
from dataclasses import asdict, is_dataclass
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
cast,
)
from langchain_core._api import beta
from langchain_core.documents import Document
from typing_extensions import override
from langchain_community.graph_vectorstores.base import GraphVectorStore, Node
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.mmr_helper import MmrHelper
from langchain_community.utilities.cassandra import SetupMode
from langchain_community.vectorstores.cassandra import Cassandra as CassandraVectorStore
CGVST = TypeVar("CGVST", bound="CassandraGraphVectorStore")
if TYPE_CHECKING:
from cassandra.cluster import Session
from langchain_core.embeddings import Embeddings
logger = logging.getLogger(__name__)
class AdjacentNode:
    """A node adjacent to one or more selected nodes, together with its embedding."""
id: str
links: list[Link]
embedding: list[float]
def __init__(self, node: Node, embedding: list[float]) -> None:
"""Create an Adjacent Node."""
self.id = node.id or ""
self.links = node.links
self.embedding = embedding
def _serialize_links(links: list[Link]) -> str:
class SetAndLinkEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any: # noqa: ANN401
if not isinstance(obj, type) and is_dataclass(obj):
return asdict(obj)
if isinstance(obj, Iterable):
return list(obj)
# Let the base class default method raise the TypeError
return super().default(obj)
return json.dumps(links, cls=SetAndLinkEncoder)
def _deserialize_links(json_blob: str | None) -> set[Link]:
return {
Link(kind=link["kind"], direction=link["direction"], tag=link["tag"])
for link in cast(list[dict[str, Any]], json.loads(json_blob or "[]"))
}
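# Links are persisted in document metadata as a JSON string. An illustrative
# round trip (a sketch, not executed at import time):
#
#     blob = _serialize_links([Link(kind="kw", direction="out", tag="foo")])
#     # blob is roughly '[{"kind": "kw", "direction": "out", "tag": "foo"}]'
#     _deserialize_links(blob)
#     # {Link(kind="kw", direction="out", tag="foo")}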
def _metadata_link_key(link: Link) -> str:
return f"link:{link.kind}:{link.tag}"
def _metadata_link_value() -> str:
return "link"
def _doc_to_node(doc: Document) -> Node:
metadata = doc.metadata.copy()
links = _deserialize_links(metadata.get(METADATA_LINKS_KEY))
metadata[METADATA_LINKS_KEY] = links
return Node(
id=doc.id,
text=doc.page_content,
metadata=metadata,
links=list(links),
)
def _incoming_links(node: Node | AdjacentNode) -> set[Link]:
return {link for link in node.links if link.direction in ["in", "bidir"]}
def _outgoing_links(node: Node | AdjacentNode) -> set[Link]:
return {link for link in node.links if link.direction in ["out", "bidir"]}
@beta()
class CassandraGraphVectorStore(GraphVectorStore):
def __init__(
self,
embedding: Embeddings,
session: Session | None = None,
keyspace: str | None = None,
table_name: str = "",
ttl_seconds: int | None = None,
*,
body_index_options: list[tuple[str, Any]] | None = None,
setup_mode: SetupMode = SetupMode.SYNC,
metadata_deny_list: Optional[list[str]] = None,
) -> None:
"""Apache Cassandra(R) for graph-vector-store workloads.
To use it, you need a recent installation of the `cassio` library
and a Cassandra cluster / Astra DB instance supporting vector capabilities.
Example:
.. code-block:: python
from langchain_community.graph_vectorstores import (
    CassandraGraphVectorStore,
)
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
session = ... # create your Cassandra session object
keyspace = 'my_keyspace' # the keyspace should exist already
table_name = 'my_graph_vector_store'
vectorstore = CassandraGraphVectorStore(
embeddings,
session,
keyspace,
table_name,
)
Args:
embedding: Embedding function to use.
session: Cassandra driver session. If not provided, it is resolved from
cassio.
keyspace: Cassandra keyspace. If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
E.g. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
setup_mode: mode used to create the Cassandra table (SYNC,
ASYNC or OFF).
metadata_deny_list: Optional list of metadata keys not to index,
i.e. to fine-tune which of the metadata fields are indexed.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Note: the `metadata_indexing` parameter from
langchain_community.utilities.cassandra.Cassandra is not
exposed since CassandraGraphVectorStore only supports the
deny_list option.
"""
self.embedding = embedding
if metadata_deny_list is None:
metadata_deny_list = []
metadata_deny_list.append(METADATA_LINKS_KEY)
self.vector_store = CassandraVectorStore(
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
setup_mode=setup_mode,
metadata_indexing=("deny_list", metadata_deny_list),
)
store_session: Session = self.vector_store.session
self._insert_node = store_session.prepare(
f"""
INSERT INTO {keyspace}.{table_name} (
row_id, body_blob, vector, attributes_blob, metadata_s
) VALUES (?, ?, ?, ?, ?)
""" # noqa: S608
)
@property
@override
def embeddings(self) -> Embeddings | None:
return self.embedding
def _get_metadata_filter(
self,
metadata: dict[str, Any] | None = None,
outgoing_link: Link | None = None,
) -> dict[str, Any]:
    """Merge a metadata filter with the marker entry for an outgoing link."""
if outgoing_link is None:
return metadata or {}
metadata_filter = {} if metadata is None else metadata.copy()
metadata_filter[_metadata_link_key(link=outgoing_link)] = _metadata_link_value()
return metadata_filter
def _restore_links(self, doc: Document) -> Document:
"""Restores the links in the document by deserializing them from metadata.
Args:
doc: A single Document
Returns:
The same Document with restored links.
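
Example (illustrative; assumes ``METADATA_LINKS_KEY == "links"``):

.. code-block:: python

    # Stored form: links serialized as a JSON string, plus a flattened
    # marker entry for each incoming link.
    doc.metadata == {
        "links": '[{"kind": "kw", "direction": "in", "tag": "foo"}]',
        "link:kw:foo": "link",
    }
    # Restored form: links deserialized to a set, marker entries removed.
    restored.metadata == {
        "links": {Link(kind="kw", direction="in", tag="foo")},
    }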
"""
links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY))
doc.metadata[METADATA_LINKS_KEY] = links
# TODO: Could this be skipped if we put these metadata entries
# only in the searchable `metadata_s` column?
for incoming_link_key in [
_metadata_link_key(link=link)
for link in links
if link.direction in ["in", "bidir"]
]:
if incoming_link_key in doc.metadata:
del doc.metadata[incoming_link_key]
return doc
def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = _serialize_links(node.links)
# TODO: Could we put these metadata entries
# only in the searchable `metadata_s` column?
for incoming_link in _incoming_links(node=node):
metadata[_metadata_link_key(link=incoming_link)] = _metadata_link_value()
return metadata
def _get_docs_for_insertion(
self, nodes: Iterable[Node]
) -> tuple[list[Document], list[str]]:
docs = []
ids = []
for node in nodes:
node_id = secrets.token_hex(8) if not node.id else node.id
doc = Document(
page_content=node.text,
metadata=self._get_node_metadata_for_insertion(node=node),
id=node_id,
)
docs.append(doc)
ids.append(node_id)
return (docs, ids)
@override
def add_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> Iterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
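
Example (a minimal sketch; assumes a configured ``store``):

.. code-block:: python

    from langchain_community.graph_vectorstores.base import Node
    from langchain_community.graph_vectorstores.links import Link

    node = Node(
        text="Cassandra is a distributed database.",
        metadata={"topic": "databases"},
        links=[Link(kind="kw", direction="bidir", tag="cassandra")],
    )
    ids = list(store.add_nodes([node]))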
"""
(docs, ids) = self._get_docs_for_insertion(nodes=nodes)
return self.vector_store.add_documents(docs, ids=ids)
@override
async def aadd_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> AsyncIterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
(docs, ids) = self._get_docs_for_insertion(nodes=nodes)
for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids):
yield inserted_id
@override
def similarity_search(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Document]:
"""Retrieve documents from this graph store.
Args:
query: The query string.
k: The number of Documents to return. Defaults to 4.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
Returns:
Collection of retrieved documents.
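
Example (illustrative; assumes a configured ``store``):

.. code-block:: python

    docs = store.similarity_search(
        query="distributed databases",
        k=2,
        filter={"topic": "databases"},
    )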
"""
return [
self._restore_links(doc)
for doc in self.vector_store.similarity_search(
query=query,
k=k,
filter=filter,
**kwargs,
)
]
@override
async def asimilarity_search(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Document]:
"""Retrieve documents from this graph store.
Args:
query: The query string.
k: The number of Documents to return. Defaults to 4.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
Returns:
Collection of retrieved documents.
"""
return [
self._restore_links(doc)
for doc in await self.vector_store.asimilarity_search(
query=query,
k=k,
filter=filter,
**kwargs,
)
]
@override
def similarity_search_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
**kwargs: Additional keyword arguments.
Returns:
The list of Documents most similar to the query vector.
"""
return [
self._restore_links(doc)
for doc in self.vector_store.similarity_search_by_vector(
embedding,
k=k,
filter=filter,
**kwargs,
)
]
@override
async def asimilarity_search_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
**kwargs: Additional keyword arguments.
Returns:
The list of Documents most similar to the query vector.
"""
return [
self._restore_links(doc)
for doc in await self.vector_store.asimilarity_search_by_vector(
embedding,
k=k,
filter=filter,
**kwargs,
)
]
def get_by_document_id(self, document_id: str) -> Document | None:
"""Retrieve a single document from the store, given its document ID.
Args:
document_id: The document ID
Returns:
The document if it exists. Otherwise None.
"""
doc = self.vector_store.get_by_document_id(document_id=document_id)
return self._restore_links(doc) if doc is not None else None
async def aget_by_document_id(self, document_id: str) -> Document | None:
"""Retrieve a single document from the store, given its document ID.
Args:
document_id: The document ID
Returns:
The document if it exists. Otherwise None.
"""
doc = await self.vector_store.aget_by_document_id(document_id=document_id)
return self._restore_links(doc) if doc is not None else None
def get_node(self, node_id: str) -> Node | None:
"""Retrieve a single node from the store, given its ID.
Args:
node_id: The node ID
Returns:
The node if it exists. Otherwise None.
"""
doc = self.vector_store.get_by_document_id(document_id=node_id)
if doc is None:
return None
return _doc_to_node(doc=doc)
@override
async def ammr_traversal_search( # noqa: C901
self,
query: str,
*,
initial_roots: Sequence[str] = (),
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
This strategy first retrieves the top `fetch_k` results by similarity to
the question. It then selects the top `k` results based on
maximum-marginal relevance using the given `lambda_mult`.
At each step, it considers the (remaining) documents from `fetch_k` as
well as any documents connected by edges to a selected document
retrieved based on similarity (a "root").
Args:
query: The query string to search for.
initial_roots: Optional list of document IDs to use for initializing search.
The top `adjacent_k` nodes adjacent to each initial root will be
included in the set of initial candidates. To fetch only in the
neighborhood of these nodes, set `fetch_k = 0`.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of initial Documents to fetch via similarity.
Will be added to the nodes adjacent to `initial_roots`.
Defaults to 100.
adjacent_k: Number of adjacent Documents to fetch.
Defaults to 10.
depth: Maximum depth of a node (number of edges) from a node
retrieved via similarity. Defaults to 2.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal to
this threshold will be chosen. Defaults to -infinity.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
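
Example (a minimal sketch; assumes a configured ``store`` and an
``async`` calling context):

.. code-block:: python

    async for doc in store.ammr_traversal_search(
        "distributed databases",
        k=4,
        depth=2,
        lambda_mult=0.25,  # favor diversity over similarity
    ):
        print(doc.id, doc.metadata["mmr_score"])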
"""
query_embedding = self.embedding.embed_query(query)
helper = MmrHelper(
k=k,
query_embedding=query_embedding,
lambda_mult=lambda_mult,
score_threshold=score_threshold,
)
# For each unselected node, stores the outgoing links.
outgoing_links_map: dict[str, set[Link]] = {}
visited_links: set[Link] = set()
# Map from id to Document
retrieved_docs: dict[str, Document] = {}
async def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
nonlocal outgoing_links_map, visited_links, retrieved_docs
# Put the neighborhood into the outgoing links, to avoid adding it
# to the candidate set in the future.
outgoing_links_map.update(
{content_id: set() for content_id in neighborhood}
)
# Initialize the visited_links with the set of outgoing links from the
# neighborhood. This prevents re-visiting them.
visited_links = await self._get_outgoing_links(neighborhood)
# Call `self._get_adjacent` to fetch the candidates.
adjacent_nodes = await self._get_adjacent(
links=visited_links,
query_embedding=query_embedding,
k_per_link=adjacent_k,
filter=filter,
retrieved_docs=retrieved_docs,
)
new_candidates: dict[str, list[float]] = {}
for adjacent_node in adjacent_nodes:
if adjacent_node.id not in outgoing_links_map:
outgoing_links_map[adjacent_node.id] = _outgoing_links(
node=adjacent_node
)
new_candidates[adjacent_node.id] = adjacent_node.embedding
helper.add_candidates(new_candidates)
async def fetch_initial_candidates() -> None:
nonlocal outgoing_links_map, visited_links, retrieved_docs
results = (
await self.vector_store.asimilarity_search_with_embedding_id_by_vector(
embedding=query_embedding,
k=fetch_k,
filter=filter,
)
)
candidates: dict[str, list[float]] = {}
for doc, embedding, doc_id in results:
if doc_id not in retrieved_docs:
retrieved_docs[doc_id] = doc
if doc_id not in outgoing_links_map:
node = _doc_to_node(doc)
outgoing_links_map[doc_id] = _outgoing_links(node=node)
candidates[doc_id] = embedding
helper.add_candidates(candidates)
if initial_roots:
await fetch_neighborhood(initial_roots)
if fetch_k > 0:
await fetch_initial_candidates()
# Tracks the depth of each candidate.
depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()}
# Select the best item, K times.
for _ in range(k):
selected_id = helper.pop_best()
if selected_id is None:
break
next_depth = depths[selected_id] + 1
if next_depth < depth:
# If the next nodes would not exceed the depth limit, find the
# adjacent nodes.
# Find the links linked to from the selected ID.
selected_outgoing_links = outgoing_links_map.pop(selected_id)
# Don't re-visit already visited links.
selected_outgoing_links.difference_update(visited_links)
# Find the nodes with incoming links from those links.
adjacent_nodes = await self._get_adjacent(
links=selected_outgoing_links,
query_embedding=query_embedding,
k_per_link=adjacent_k,
filter=filter,
retrieved_docs=retrieved_docs,
)
# Record the selected_outgoing_links as visited.
visited_links.update(selected_outgoing_links)
new_candidates = {}
for adjacent_node in adjacent_nodes:
if adjacent_node.id not in outgoing_links_map:
outgoing_links_map[adjacent_node.id] = _outgoing_links(
node=adjacent_node
)
new_candidates[adjacent_node.id] = adjacent_node.embedding
if next_depth < depths.get(adjacent_node.id, depth + 1):
# If this is a new shortest depth, or there was no
# previous depth, update the depths. This ensures that
# when we discover a node we will have the shortest
# depth available.
#
# NOTE: No effort is made to traverse from nodes that
# were previously selected if they become reachable via
# a shorter path via nodes selected later. This is
# currently "intended", but may be worth experimenting
# with.
depths[adjacent_node.id] = next_depth
helper.add_candidates(new_candidates)
for doc_id, similarity_score, mmr_score in zip(
helper.selected_ids,
helper.selected_similarity_scores,
helper.selected_mmr_scores,
):
if doc_id in retrieved_docs:
doc = self._restore_links(retrieved_docs[doc_id])
doc.metadata["similarity_score"] = similarity_score
doc.metadata["mmr_score"] = mmr_score
yield doc
else:
msg = f"retrieved_docs should contain id: {doc_id}"
raise RuntimeError(msg)
@override
def mmr_traversal_search(
self,
query: str,
*,
initial_roots: Sequence[str] = (),
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
This strategy first retrieves the top `fetch_k` results by similarity to
the question. It then selects the top `k` results based on
maximum-marginal relevance using the given `lambda_mult`.
At each step, it considers the (remaining) documents from `fetch_k` as
well as any documents connected by edges to a selected document
retrieved based on similarity (a "root").
Args:
query: The query string to search for.
initial_roots: Optional list of document IDs to use for initializing search.
The top `adjacent_k` nodes adjacent to each initial root will be
included in the set of initial candidates. To fetch only in the
neighborhood of these nodes, set `fetch_k = 0`.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of initial Documents to fetch via similarity.
Will be added to the nodes adjacent to `initial_roots`.
Defaults to 100.
adjacent_k: Number of adjacent Documents to fetch.
Defaults to 10.
depth: Maximum depth of a node (number of edges) from a node
retrieved via similarity. Defaults to 2.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal to
this threshold will be chosen. Defaults to -infinity.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
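
Example (illustrative; assumes a configured ``store``; note this sync
wrapper calls ``asyncio.run``, so it cannot be used while an event loop
is already running):

.. code-block:: python

    for doc in store.mmr_traversal_search("distributed databases", k=4):
        print(doc.id, doc.metadata["similarity_score"])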
"""
async def collect_docs() -> Iterable[Document]:
async_iter = self.ammr_traversal_search(
query=query,
initial_roots=initial_roots,
k=k,
depth=depth,
fetch_k=fetch_k,
adjacent_k=adjacent_k,
lambda_mult=lambda_mult,
score_threshold=score_threshold,
filter=filter,
**kwargs,
)
return [doc async for doc in async_iter]
return asyncio.run(collect_docs())
@override
async def atraversal_search( # noqa: C901
self,
query: str,
*,
k: int = 4,
depth: int = 1,
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from this knowledge store.
First, `k` nodes are retrieved using a vector search for the `query` string.
Then, additional nodes are discovered up to the given `depth` from those
starting nodes.
Args:
query: The query string.
k: The number of Documents to return from the initial vector search.
Defaults to 4.
depth: The maximum depth of edges to traverse. Defaults to 1.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
Returns:
Collection of retrieved documents.
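
Example (a minimal sketch; assumes a configured ``store`` and an
``async`` calling context):

.. code-block:: python

    async for doc in store.atraversal_search(
        "distributed databases", k=4, depth=1
    ):
        print(doc.id)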
"""
# Depth 0:
# Query for `k` nodes similar to the question.
# Retrieve `content_id` and `outgoing_links()`.
#
# Depth 1:
# Query for nodes that have an incoming link in the `outgoing_links()` set.
# Combine node IDs.
# Query for `outgoing_links()` of those "new" node IDs.
#
# ...
# Map from visited ID to depth
visited_ids: dict[str, int] = {}
# Map from visited link to depth
visited_links: dict[Link, int] = {}
# Map from id to Document
retrieved_docs: dict[str, Document] = {}
async def visit_nodes(d: int, docs: Iterable[Document]) -> None:
"""Recursively visit nodes and their outgoing links."""
nonlocal visited_ids, visited_links, retrieved_docs
# Iterate over nodes, tracking the *new* outgoing links for this
# depth. These are links that are either new, or newly discovered at a
# lower depth.
outgoing_links: set[Link] = set()
for doc in docs:
if doc.id is not None:
if doc.id not in retrieved_docs:
retrieved_docs[doc.id] = doc
# If this node is at a closer depth, update visited_ids
if d <= visited_ids.get(doc.id, depth):
visited_ids[doc.id] = d
# If we can continue traversing from this node,
if d < depth:
node = _doc_to_node(doc=doc)
# Record any new (or newly discovered at a lower depth)
# links to the set to traverse.
for link in _outgoing_links(node=node):
if d <= visited_links.get(link, depth):
# Record that we'll query this link at the
# given depth, so we don't fetch it again
# (unless we find it at an earlier depth)
visited_links[link] = d
outgoing_links.add(link)
if outgoing_links:
metadata_search_tasks = []
for outgoing_link in outgoing_links:
metadata_filter = self._get_metadata_filter(
metadata=filter,
outgoing_link=outgoing_link,
)
metadata_search_tasks.append(
asyncio.create_task(
self.vector_store.ametadata_search(
filter=metadata_filter, n=1000
)
)
)
results = await asyncio.gather(*metadata_search_tasks)
# Visit targets concurrently
visit_target_tasks = [
visit_targets(d=d + 1, docs=docs) for docs in results
]
await asyncio.gather(*visit_target_tasks)
async def visit_targets(d: int, docs: Iterable[Document]) -> None:
"""Visit target nodes retrieved from outgoing links."""
nonlocal visited_ids, retrieved_docs
new_ids_at_next_depth = set()
for doc in docs:
if doc.id is not None:
if doc.id not in retrieved_docs:
retrieved_docs[doc.id] = doc
if d <= visited_ids.get(doc.id, depth):
new_ids_at_next_depth.add(doc.id)
if new_ids_at_next_depth:
visit_node_tasks = [
visit_nodes(d=d, docs=[retrieved_docs[doc_id]])
for doc_id in new_ids_at_next_depth
if doc_id in retrieved_docs
]
fetch_tasks = [
asyncio.create_task(
self.vector_store.aget_by_document_id(document_id=doc_id)
)
for doc_id in new_ids_at_next_depth
if doc_id not in retrieved_docs
]
new_docs: list[Document | None] = await asyncio.gather(*fetch_tasks)
visit_node_tasks.extend(
visit_nodes(d=d, docs=[new_doc])
for new_doc in new_docs
if new_doc is not None
)
await asyncio.gather(*visit_node_tasks)
# Start the traversal
initial_docs = self.vector_store.similarity_search(
query=query,
k=k,
filter=filter,
)
await visit_nodes(d=0, docs=initial_docs)
for doc_id in visited_ids:
if doc_id in retrieved_docs:
yield self._restore_links(retrieved_docs[doc_id])
else:
msg = f"retrieved_docs should contain id: {doc_id}"
raise RuntimeError(msg)
@override
def traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 1,
filter: dict[str, Any] | None = None,
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from this knowledge store.
First, `k` nodes are retrieved using a vector search for the `query` string.
Then, additional nodes are discovered up to the given `depth` from those
starting nodes.
Args:
query: The query string.
k: The number of Documents to return from the initial vector search.
Defaults to 4.
depth: The maximum depth of edges to traverse. Defaults to 1.
filter: Optional metadata to filter the results.
**kwargs: Additional keyword arguments.
Returns:
Collection of retrieved documents.
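
Example (illustrative; assumes a configured ``store`` and no
already-running event loop, since this wrapper uses ``asyncio.run``):

.. code-block:: python

    docs = list(store.traversal_search("distributed databases", depth=2))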
"""
async def collect_docs() -> Iterable[Document]:
async_iter = self.atraversal_search(
query=query,
k=k,
depth=depth,
filter=filter,
**kwargs,
)
return [doc async for doc in async_iter]
return asyncio.run(collect_docs())
async def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]:
"""Return the set of outgoing links for the given source IDs asynchronously.
Args:
source_ids: The IDs of the source nodes to retrieve outgoing links for.
Returns:
A set of `Link` objects representing the outgoing links from the source
nodes.
"""
links = set()
# Create coroutine objects without scheduling them yet
coroutines = [
self.vector_store.aget_by_document_id(document_id=source_id)
for source_id in source_ids
]
# Schedule and await all coroutines
docs = await asyncio.gather(*coroutines)
for doc in docs:
if doc is not None:
node = _doc_to_node(doc=doc)
links.update(_outgoing_links(node=node))
return links
async def _get_adjacent(
self,
links: set[Link],
query_embedding: list[float],
retrieved_docs: dict[str, Document],
k_per_link: int | None = None,
filter: dict[str, Any] | None = None, # noqa: A002
) -> Iterable[AdjacentNode]:
"""Return the target nodes with incoming links from any of the given links.
Args:
links: The links to look for.
query_embedding: The query embedding. Used to rank target nodes.
retrieved_docs: A cache of retrieved docs. This will be added to.
k_per_link: The number of target nodes to fetch for each link.
filter: Optional metadata to filter the results.
Returns:
Iterable of adjacent edges.
"""
targets: dict[str, AdjacentNode] = {}
tasks = []
for link in links:
metadata_filter = self._get_metadata_filter(
metadata=filter,
outgoing_link=link,
)
tasks.append(
self.vector_store.asimilarity_search_with_embedding_id_by_vector(
embedding=query_embedding,
k=k_per_link or 10,
filter=metadata_filter,
)
)
results = await asyncio.gather(*tasks)
for result in results:
for doc, embedding, doc_id in result:
if doc_id not in retrieved_docs:
retrieved_docs[doc_id] = doc
if doc_id not in targets:
node = _doc_to_node(doc=doc)
targets[doc_id] = AdjacentNode(node=node, embedding=embedding)
# TODO: Consider a combined limit based on the similarity and/or
# predicted MMR score?
return targets.values()
@staticmethod
def _build_docs_from_texts(
texts: List[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[Document]:
docs: List[Document] = []
for i, text in enumerate(texts):
doc = Document(
page_content=text,
)
if metadatas is not None:
doc.metadata = metadatas[i]
if ids is not None:
doc.id = ids[i]
docs.append(doc)
return docs
@classmethod
def from_texts(
cls: Type[CGVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
*,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Optional[list[str]] = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from raw texts.
Args:
texts: Texts to add to the vectorstore.
embedding: Embedding function to use.
metadatas: Optional list of metadatas associated with the texts.
session: Cassandra driver session.
If not provided, it is resolved from cassio.
keyspace: Cassandra keyspace.
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
E.g. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_deny_list: Optional list of metadata keys not to index,
i.e. to fine-tune which of the metadata fields are indexed.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Note: the `metadata_indexing` parameter from
langchain_community.utilities.cassandra.Cassandra is not
exposed since CassandraGraphVectorStore only supports the
deny_list option.
Returns:
a CassandraGraphVectorStore.
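
Example (a minimal sketch; ``embeddings`` and ``session`` are assumed
to be configured as in the class-level example):

.. code-block:: python

    store = CassandraGraphVectorStore.from_texts(
        texts=["first document", "second document"],
        embedding=embeddings,
        session=session,
        keyspace="my_keyspace",
        table_name="my_graph_vector_store",
    )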
"""
docs = cls._build_docs_from_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
)
return cls.from_documents(
documents=docs,
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_deny_list=metadata_deny_list,
**kwargs,
)
@classmethod
async def afrom_texts(
cls: Type[CGVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
*,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Optional[list[str]] = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from raw texts.
Args:
texts: Texts to add to the vectorstore.
embedding: Embedding function to use.
metadatas: Optional list of metadatas associated with the texts.
session: Cassandra driver session.
If not provided, it is resolved from cassio.
keyspace: Cassandra keyspace.
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
E.g. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_deny_list: Optional list of metadata keys not to index,
i.e. to fine-tune which of the metadata fields are indexed.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Note: the `metadata_indexing` parameter from
langchain_community.utilities.cassandra.Cassandra is not
exposed since CassandraGraphVectorStore only supports the
deny_list option.
Returns:
a CassandraGraphVectorStore.
"""
docs = cls._build_docs_from_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
)
return await cls.afrom_documents(
documents=docs,
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_deny_list=metadata_deny_list,
**kwargs,
)
@staticmethod
def _add_ids_to_docs(
docs: List[Document],
ids: Optional[List[str]] = None,
) -> List[Document]:
if ids is not None:
for doc, doc_id in zip(docs, ids):
doc.id = doc_id
return docs
@classmethod
def from_documents(
cls: Type[CGVST],
documents: List[Document],
embedding: Embeddings,
*,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Optional[list[str]] = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from a document list.
Args:
documents: Documents to add to the vectorstore.
embedding: Embedding function to use.
session: Cassandra driver session.
If not provided, it is resolved from cassio.
keyspace: Cassandra keyspace.
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
E.g. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_deny_list: Optional list of metadata keys not to index,
i.e. to fine-tune which of the metadata fields are indexed.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Note: the `metadata_indexing` parameter from
langchain_community.utilities.cassandra.Cassandra is not
exposed since CassandraGraphVectorStore only supports the
deny_list option.
Returns:
a CassandraGraphVectorStore.
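
Example (a minimal sketch; ``embeddings`` and ``session`` are assumed
to be configured as in the class-level example):

.. code-block:: python

    from langchain_core.documents import Document

    docs = [Document(page_content="...", metadata={"topic": "databases"})]
    store = CassandraGraphVectorStore.from_documents(
        documents=docs,
        embedding=embeddings,
        session=session,
        keyspace="my_keyspace",
        table_name="my_graph_vector_store",
    )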
"""
store = cls(
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_deny_list=metadata_deny_list,
**kwargs,
)
store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
return store
@classmethod
async def afrom_documents(
cls: Type[CGVST],
documents: List[Document],
embedding: Embeddings,
*,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Optional[list[str]] = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from a document list.
Args:
documents: Documents to add to the vectorstore.
embedding: Embedding function to use.
session: Cassandra driver session.
If not provided, it is resolved from cassio.
keyspace: Cassandra keyspace.
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
E.g. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
metadata_deny_list: Optional list of metadata keys not to index,
i.e. to fine-tune which of the metadata fields are indexed.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
Note: the `metadata_indexing` parameter from
langchain_community.utilities.cassandra.Cassandra is not
exposed since CassandraGraphVectorStore only supports the
deny_list option.
Returns:
a CassandraGraphVectorStore.
"""
store = cls(
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_deny_list=metadata_deny_list,
**kwargs,
)
await store.aadd_documents(
documents=cls._add_ids_to_docs(docs=documents, ids=ids)
)
return store