import time
from itertools import repeat
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

# Imports assumed from langchain_core; adjust for older langchain versions.
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore


class XataVectorStore(VectorStore):
    """`Xata` vector store.

    It assumes you have a Xata database created with the right schema.
    See the guide at:
    https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore
    """
    def __init__(
        self,
        api_key: str,
        db_url: str,
        embedding: Embeddings,
        table_name: str,
    ) -> None:
        """Initialize with Xata client."""
        try:
            from xata.client import XataClient
        except ImportError:
            raise ImportError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            )
        self._client = XataClient(api_key=api_key, db_url=db_url)
        self._embedding: Embeddings = embedding
        self._table_name = table_name or "vectors"
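A minimal construction sketch, assuming credentials live in the `XATA_API_KEY` and `XATA_DB_URL` environment variables and that `OpenAIEmbeddings` from `langchain_openai` is the embedding model; both are illustrative choices, not requirements of the class:

```python
import os

from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

# Illustrative: credentials are read from environment variables here.
store = XataVectorStore(
    api_key=os.environ["XATA_API_KEY"],
    db_url=os.environ["XATA_DB_URL"],
    embedding=OpenAIEmbeddings(),
    table_name="vectors",  # must match a Xata table with the expected schema
)
```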
    def _add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add vectors to the Xata database."""
        rows: List[Dict[str, Any]] = []
        for idx, embedding in enumerate(vectors):
            row = {
                "content": documents[idx].page_content,
                "embedding": embedding,
            }
            if ids:
                row["id"] = ids[idx]
            for key, val in documents[idx].metadata.items():
                if key not in ["id", "content", "embedding"]:
                    row[key] = val
            rows.append(row)

        # XXX: I would have liked to use the BulkProcessor here, but it
        # doesn't return the IDs, which we need here. Manual chunking it is.
        chunk_size = 1000
        id_list: List[str] = []
        for i in range(0, len(rows), chunk_size):
            chunk = rows[i : i + chunk_size]
            r = self._client.records().bulk_insert(
                self._table_name, {"records": chunk}
            )
            if r.status_code != 200:
                raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
            id_list.extend(r["recordIDs"])
        return id_list

    @staticmethod
    def _texts_to_documents(
        texts: Iterable[str],
        metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
    ) -> List[Document]:
        """Return list of Documents from list of texts and metadatas."""
        if metadatas is None:
            metadatas = repeat({})
        docs = [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(texts, metadatas)
        ]
        return docs
    @classmethod
    def from_texts(
        cls: Type["XataVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        api_key: Optional[str] = None,
        db_url: Optional[str] = None,
        table_name: str = "vectors",
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "XataVectorStore":
        """Return VectorStore initialized from texts and embeddings."""
        if not api_key or not db_url:
            raise ValueError("Xata api_key and db_url must be set.")
        embeddings = embedding.embed_documents(texts)
        ids = None  # Xata will generate them for us
        docs = cls._texts_to_documents(texts, metadatas)
        vector_db = cls(
            api_key=api_key,
            db_url=db_url,
            embedding=embedding,
            table_name=table_name,
        )
        vector_db._add_vectors(embeddings, docs, ids)
        return vector_db
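A sketch of the one-shot path through `from_texts`, which embeds the texts, wraps them in `Document` objects, and bulk-inserts them in a single call; the texts, metadata keys, and embedding model here are illustrative:

```python
store = XataVectorStore.from_texts(
    texts=["Xata is a serverless data platform.", "LangChain chains LLM calls."],
    embedding=OpenAIEmbeddings(),  # illustrative embedding model
    metadatas=[{"source": "docs"}, {"source": "readme"}],
    api_key=os.environ["XATA_API_KEY"],
    db_url=os.environ["XATA_DB_URL"],
    table_name="vectors",
)
```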
    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        documents = [d[0] for d in docs_and_scores]
        return documents
    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Xata and return documents with scores.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[dict]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query text, each paired with its similarity score.
        """
        embedding = self._embedding.embed_query(query)
        payload = {
            "queryVector": embedding,
            "column": "embedding",
            "size": k,
        }
        if filter:
            payload["filter"] = filter
        r = self._client.data().vector_search(self._table_name, payload=payload)
        if r.status_code != 200:
            raise Exception(f"Error running similarity search: {r.status_code} {r}")
        hits = r["records"]
        docs_and_scores = [
            (
                Document(
                    page_content=hit["content"],
                    metadata=self._extractMetadata(hit),
                ),
                hit["xata"]["score"],
            )
            for hit in hits
        ]
        return docs_and_scores
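A query sketch; the `filter` dict is passed straight through to Xata's `vector_search` endpoint, so it follows Xata's filter expression syntax, and the `source` column here is illustrative:

```python
results = store.similarity_search_with_score(
    "What is Xata?",
    k=2,
    filter={"source": "docs"},  # illustrative metadata column
)
for doc, score in results:
    print(f"{score:.3f}  {doc.page_content}")
```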
    def _extractMetadata(self, record: dict) -> dict:
        """Extract metadata from a record. Filters out known columns."""
        metadata = {}
        for key, val in record.items():
            if key not in ["id", "content", "embedding", "xata"]:
                metadata[key] = val
        return metadata
    def delete(
        self,
        ids: Optional[List[str]] = None,
        delete_all: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            delete_all: Delete all records in the table.
        """
        if delete_all:
            self._delete_all()
            self.wait_for_indexing(ndocs=0)
        elif ids is not None:
            chunk_size = 500
            for i in range(0, len(ids), chunk_size):
                chunk = ids[i : i + chunk_size]
                operations = [
                    {"delete": {"table": self._table_name, "id": id}} for id in chunk
                ]
                self._client.records().transaction(payload={"operations": operations})
        else:
            raise ValueError("Either ids or delete_all must be set.")
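A deletion sketch covering both modes; the record IDs are illustrative placeholders, not real Xata IDs:

```python
# Delete specific rows by their Xata record IDs (illustrative values)...
store.delete(ids=["rec_abc123", "rec_def456"])

# ...or wipe the table entirely; this also waits for the search index
# to report zero documents before returning.
store.delete(delete_all=True)
```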
    def _delete_all(self) -> None:
        """Delete all records in the table."""
        while True:
            r = self._client.data().query(
                self._table_name, payload={"columns": ["id"]}
            )
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if len(ids) == 0:
                break
            operations = [
                {"delete": {"table": self._table_name, "id": id}} for id in ids
            ]
            self._client.records().transaction(payload={"operations": operations})
    def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
        """Wait for the search index to contain a certain number of documents.

        Useful in tests.
        """
        start = time.time()
        while True:
            r = self._client.data().search_table(
                self._table_name, payload={"query": "", "page": {"size": 0}}
            )
            if r.status_code != 200:
                raise Exception(f"Error running search: {r.status_code} {r}")
            if r["totalCount"] == ndocs:
                break
            if time.time() - start > timeout:
                raise Exception("Timed out waiting for indexing to complete.")
            time.sleep(0.5)
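A sketch of how a test might use `wait_for_indexing`, since Xata's search index is eventually consistent; this assumes the store was seeded with two documents, for example via the `from_texts` call above:

```python
# Two documents were inserted above; block until the search index reflects
# them, failing if indexing takes longer than 10 seconds.
store.wait_for_indexing(timeout=10, ndocs=2)
```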