from __future__ import annotations
import time
from itertools import repeat
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
class XataVectorStore(VectorStore):
    """`Xata` vector store.

    It assumes you have a Xata database
    created with the right schema. See the guide at:
    https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore

    Each document is stored as one record with a ``content`` column (the page
    text), an ``embedding`` column (the vector) and one extra column per
    metadata key.
    """

    # Maximum number of records sent in a single bulk-insert request.
    _INSERT_CHUNK_SIZE = 1000
    # Maximum number of delete operations sent in a single transaction request.
    _DELETE_CHUNK_SIZE = 500
    # Columns managed by this class; metadata keys with these names are not
    # written as separate columns on insert.
    _RESERVED_WRITE_COLUMNS = ("id", "content", "embedding")
    # Columns dropped when rebuilding metadata from a returned record; "xata"
    # carries per-hit information such as the search score.
    _RESERVED_READ_COLUMNS = ("id", "content", "embedding", "xata")

    def __init__(
        self,
        api_key: str,
        db_url: str,
        embedding: Embeddings,
        table_name: str,
    ) -> None:
        """Initialize with Xata client.

        Args:
            api_key: Xata API key, passed to ``XataClient``.
            db_url: Xata database URL, passed to ``XataClient``.
            embedding: Embedding model used for both documents and queries.
            table_name: Table holding the vectors; an empty value falls back
                to ``"vectors"``.

        Raises:
            ImportError: If the ``xata`` package is not installed.
        """
        try:
            from xata.client import XataClient
        except ImportError as e:
            raise ImportError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            ) from e
        self._client = XataClient(api_key=api_key, db_url=db_url)
        self._embedding: Embeddings = embedding
        self._table_name = table_name or "vectors"

    @property
    def embeddings(self) -> Embeddings:
        """The embedding model used by this store."""
        return self._embedding

    def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add precomputed vectors together with their documents.

        Args:
            vectors: One embedding per document.
            documents: Documents whose text and metadata are stored.
            ids: Optional record IDs, one per vector.

        Returns:
            The IDs of the inserted records, as reported by Xata.
        """
        return self._add_vectors(vectors, documents, ids)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and store them with their metadata.

        Args:
            texts: Texts to embed and store. Any iterable is accepted,
                including one-shot iterators.
            metadatas: Optional metadata dict per text.
            ids: Optional record IDs, one per text.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The IDs of the inserted records.
        """
        # Materialize once: the texts are read twice below (to build the
        # documents and to embed them), which would exhaust a generator.
        text_list = list(texts)
        docs = self._texts_to_documents(text_list, metadatas)
        vectors = self._embedding.embed_documents(text_list)
        return self.add_vectors(vectors, docs, ids)

    def _add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add vectors to the Xata database.

        ``documents[i]`` (and ``ids[i]`` when ``ids`` is given) must exist for
        every index ``i`` of ``vectors``.

        Returns:
            The record IDs returned by Xata, in insertion order.

        Raises:
            Exception: If a bulk-insert request does not return HTTP 200.
                Chunks sent before the failing one stay inserted.
        """
        rows: List[Dict[str, Any]] = []
        for idx, embedding in enumerate(vectors):
            doc = documents[idx]
            row: Dict[str, Any] = {
                "content": doc.page_content,
                "embedding": embedding,
            }
            if ids:
                row["id"] = ids[idx]
            for key, val in doc.metadata.items():
                if key not in self._RESERVED_WRITE_COLUMNS:
                    row[key] = val
            rows.append(row)

        # XXX: I would have liked to use the BulkProcessor here, but it
        # doesn't return the IDs, which we need here. Manual chunking it is.
        size = self._INSERT_CHUNK_SIZE
        id_list: List[str] = []
        for start in range(0, len(rows), size):
            chunk = rows[start : start + size]
            r = self._client.records().bulk_insert(
                self._table_name, {"records": chunk}
            )
            if r.status_code != 200:
                raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
            id_list.extend(r["recordIDs"])
        return id_list

    @staticmethod
    def _texts_to_documents(
        texts: Iterable[str],
        metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
    ) -> List[Document]:
        """Return list of Documents from list of texts and metadatas.

        When ``metadatas`` is given, texts and metadatas are paired up with
        ``zip``, so the result is as long as the shorter of the two.
        """
        if metadatas is None:
            # A fresh dict per document; ``repeat({})`` would hand the same
            # mutable dict object to every document.
            return [Document(page_content=text, metadata={}) for text in texts]
        return [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(texts, metadatas)
        ]

    @classmethod
    def from_texts(
        cls: Type["XataVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        api_key: Optional[str] = None,
        db_url: Optional[str] = None,
        table_name: str = "vectors",
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "XataVectorStore":
        """Return VectorStore initialized from texts and embeddings.

        Args:
            texts: Texts to embed and store.
            embedding: Embedding model for documents and later queries.
            metadatas: Optional metadata dict per text.
            api_key: Xata API key (required).
            db_url: Xata database URL (required).
            table_name: Table holding the vectors.
            ids: Optional record IDs, one per text. When omitted, Xata
                generates the IDs.
            **kwargs: Unused; accepted for interface compatibility.

        Raises:
            ValueError: If ``api_key`` or ``db_url`` is missing.
        """
        if not api_key or not db_url:
            raise ValueError("Xata api_key and db_url must be set.")
        embeddings = embedding.embed_documents(texts)
        docs = cls._texts_to_documents(texts, metadatas)
        vector_db = cls(
            api_key=api_key,
            db_url=db_url,
            embedding=embedding,
            table_name=table_name,
        )
        # Previously the caller's ``ids`` were overwritten with None and
        # silently ignored; pass them through (None still lets Xata choose).
        vector_db._add_vectors(embeddings, docs, ids)
        return vector_db

    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Optional Xata filter applied to the search.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        return [doc for doc, _score in docs_and_scores]

    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run a vector search against the Xata table and return scores.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[dict]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: Documents most similar to the query,
            each paired with the score Xata reports for the hit
            (``record["xata"]["score"]``).

        Raises:
            Exception: If the vector-search request does not return HTTP 200.
        """
        embedding = self._embedding.embed_query(query)
        payload: Dict[str, Any] = {
            "queryVector": embedding,
            "column": "embedding",
            "size": k,
        }
        if filter:
            payload["filter"] = filter
        r = self._client.data().vector_search(self._table_name, payload=payload)
        if r.status_code != 200:
            raise Exception(f"Error running similarity search: {r.status_code} {r}")
        return [
            (
                Document(
                    page_content=hit["content"],
                    metadata=self._extractMetadata(hit),
                ),
                hit["xata"]["score"],
            )
            for hit in r["records"]
        ]

    def _extractMetadata(self, record: dict) -> dict:
        """Extract metadata from a record. Filters out known columns."""
        return {
            key: val
            for key, val in record.items()
            if key not in self._RESERVED_READ_COLUMNS
        }

    def delete(
        self,
        ids: Optional[List[str]] = None,
        delete_all: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            delete_all: Delete all records in the table, then wait until the
                search index reports zero documents.

        Raises:
            ValueError: If neither ``ids`` nor ``delete_all`` is set.
            Exception: If a Xata request does not return HTTP 200.
        """
        if delete_all:
            self._delete_all()
            self.wait_for_indexing(ndocs=0)
        elif ids is not None:
            size = self._DELETE_CHUNK_SIZE
            for start in range(0, len(ids), size):
                self._delete_records(ids[start : start + size])
        else:
            raise ValueError("Either ids or delete_all must be set.")

    def _delete_records(self, record_ids: List[str]) -> None:
        """Delete the given records in one transaction request.

        Raises:
            Exception: If the transaction request does not return HTTP 200.
        """
        operations = [
            {"delete": {"table": self._table_name, "id": record_id}}
            for record_id in record_ids
        ]
        r = self._client.records().transaction(payload={"operations": operations})
        if r.status_code != 200:
            raise Exception(f"Error deleting records: {r.status_code} {r}")

    def _delete_all(self) -> None:
        """Delete all records in the table, one query page at a time."""
        while True:
            r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            record_ids = [rec["id"] for rec in r["records"]]
            if not record_ids:
                break
            # A failed delete must raise; otherwise the next query returns the
            # same records and this loop never terminates.
            self._delete_records(record_ids)

    def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
        """Wait for the search index to contain a certain number of
        documents. Useful in tests.

        Polls the table's search endpoint every 0.5 seconds until its
        ``totalCount`` equals ``ndocs``.

        Args:
            timeout: Seconds to keep polling before giving up.
            ndocs: Document count to wait for.

        Raises:
            Exception: If a search request does not return HTTP 200, or the
                count is not reached within ``timeout`` seconds.
        """
        # Monotonic clock: immune to wall-clock adjustments while waiting.
        start = time.monotonic()
        while True:
            r = self._client.data().search_table(
                self._table_name, payload={"query": "", "page": {"size": 0}}
            )
            if r.status_code != 200:
                raise Exception(f"Error running search: {r.status_code} {r}")
            if r["totalCount"] == ndocs:
                break
            if time.monotonic() - start > timeout:
                raise Exception("Timed out waiting for indexing to complete.")
            time.sleep(0.5)