"""Wrapper around scikit-learn NearestNeighbors implementation.
The vector store can be persisted in json, bson or parquet format.
"""
import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
from uuid import uuid4
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
DEFAULT_K = 4  # Default number of Documents returned by a search.
DEFAULT_FETCH_K = 20  # Default number of candidates fetched before MMR re-ranking.
[docs]
class BaseSerializer(ABC):
    """Abstract interface for persisting data at a fixed file path.

    Concrete subclasses pick the on-disk format by implementing
    ``extension``, ``save`` and ``load``.
    """

    def __init__(self, persist_path: str) -> None:
        """Remember the location that ``save`` and ``load`` operate on.

        Args:
            persist_path: Path of the file used for persistence.
        """
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """Return the file extension suggested by this serializer, without the dot."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path``.

        Args:
            data: The object to persist.
        """

    @abstractmethod
    def load(self) -> Any:
        """Read back and return whatever is stored at ``persist_path``."""
[docs]
class JsonSerializer(BaseSerializer):
    """Serialize data in JSON using the json package from python standard library."""

    @classmethod
    def extension(cls) -> str:
        """Return the file extension suggested by this serializer (without dot)."""
        return "json"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as JSON.

        Args:
            data: A JSON-serializable object.

        Raises:
            TypeError: If ``data`` contains a value ``json`` cannot encode. The
                file at ``persist_path`` is left untouched in that case.
        """
        # Encode first: opening with mode "w" truncates the file, so encoding
        # inside the ``with`` block would destroy the previously persisted
        # data whenever serialization fails.
        payload = json.dumps(data)
        with open(self.persist_path, "w", encoding="utf-8") as fp:
            fp.write(payload)

    def load(self) -> Any:
        """Read ``persist_path`` and return the decoded JSON value."""
        # JSON is UTF-8 by specification; files written by ``save`` are pure
        # ASCII (``ensure_ascii`` default), so they decode identically.
        with open(self.persist_path, "r", encoding="utf-8") as fp:
            return json.load(fp)
[docs]
class BsonSerializer(BaseSerializer):
    """Serialize data in Binary JSON using the `bson` python package."""

    def __init__(self, persist_path: str) -> None:
        """Store the path and import the ``bson`` package via ``guard_import``.

        Args:
            persist_path: Path of the file used for persistence.
        """
        super().__init__(persist_path)
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        """Return the file extension suggested by this serializer (without dot)."""
        return "bson"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as BSON bytes.

        Args:
            data: The object to encode with ``bson.dumps``.
        """
        # Encode before opening: mode "wb" truncates the file immediately, so
        # a failing ``dumps`` inside the ``with`` block would wipe the
        # previously persisted data.
        payload = self.bson.dumps(data)
        with open(self.persist_path, "wb") as fp:
            fp.write(payload)

    def load(self) -> Any:
        """Read ``persist_path`` and return the value decoded by ``bson.loads``."""
        with open(self.persist_path, "rb") as fp:
            return self.bson.loads(fp.read())
[docs]
class ParquetSerializer(BaseSerializer):
    """Serialize data in `Apache Parquet` format using the `pyarrow` package."""

    def __init__(self, persist_path: str) -> None:
        """Store the path and import ``pandas``, ``pyarrow`` and ``pyarrow.parquet``.

        Args:
            persist_path: Path of the file used for persistence.
        """
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        """Return the file extension suggested by this serializer (without dot)."""
        return "parquet"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as a Parquet table.

        When a file already exists at ``persist_path`` it is moved aside to
        ``<persist_path>-backup`` for the duration of the write and restored
        if the write fails; the backup is deleted after a successful write.

        Args:
            data: Anything accepted by the ``pandas.DataFrame`` constructor.
        """
        frame = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(frame)

        if not os.path.exists(self.persist_path):
            self.pq.write_table(table, self.persist_path)
            return

        backup_path = str(self.persist_path) + "-backup"
        # ``os.replace`` overwrites an existing destination on every platform;
        # ``os.rename`` raises on Windows when the target exists (for example
        # a stale backup left behind by an interrupted earlier save).
        os.replace(self.persist_path, backup_path)
        try:
            self.pq.write_table(table, self.persist_path)
        except Exception:
            # A failed write may leave a partial file at ``persist_path``;
            # ``os.replace`` lets the backup overwrite it during rollback.
            os.replace(backup_path, self.persist_path)
            raise
        os.remove(backup_path)

    def load(self) -> Any:
        """Read the Parquet file and return a dict of column name to list of values."""
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
# Maps the ``serializer`` name accepted by ``SKLearnVectorStore`` to the
# serializer class that implements that on-disk format.
SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    "json": JsonSerializer,
    "bson": BsonSerializer,
    "parquet": ParquetSerializer,
}
[docs]
class SKLearnVectorStoreException(RuntimeError):
    """Error type raised by SKLearnVectorStore for store-specific failures."""
[docs]
class SKLearnVectorStore(VectorStore):
    """Simple in-memory vector store based on the `scikit-learn` library
    `NearestNeighbors`.

    Texts, metadata, ids and embeddings are kept in parallel Python lists; a
    ``NearestNeighbors`` index is refitted on the full embedding matrix every
    time texts are added or a persisted collection is loaded. Documents
    returned by searches carry the stored id under the ``"id"`` metadata key.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        """Create the store and load an existing collection if one is persisted.

        Args:
            embedding: Embedding function used for documents and queries.
            persist_path: Optional file path used by ``persist``. If a file
                already exists at that path it is loaded immediately.
            serializer: Key into ``SERIALIZER_MAP`` selecting the file format.
            metric: Distance metric passed to ``NearestNeighbors``.
            **kwargs: Extra keyword arguments passed to ``NearestNeighbors``.
        """
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    @property
    def embeddings(self) -> Embeddings:
        """The embedding function supplied at construction time."""
        return self._embedding_function

    def persist(self) -> None:
        """Save ids, texts, metadatas and embeddings with the configured serializer.

        Raises:
            SKLearnVectorStoreException: If no ``persist_path`` was given on
                creation.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        """Replace the in-memory collection with the persisted one and refit.

        Raises:
            SKLearnVectorStoreException: If no ``persist_path`` was given on
                creation.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the " "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        # An empty collection can legitimately have been persisted; refitting
        # on it would raise and make the constructor fail, so leave the index
        # unfitted until data is added.
        if len(self._embeddings) > 0:
            self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts``, append them to the store and refit the index.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata dicts, one per text. Defaults to a
                separate empty dict for each text.
            ids: Optional ids, one per text. Defaults to random UUID4 strings.
            **kwargs: Unused.

        Returns:
            The ids of the added texts.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length that
                differs from the number of texts.
            SKLearnVectorStoreException: If the store holds no data after the
                call (nothing was ever added).
        """
        _texts = list(texts)
        _ids = list(ids) if ids else [str(uuid4()) for _ in _texts]
        # One fresh dict per document: ``[{}] * n`` would store the *same*
        # dict n times, so mutating one entry's metadata would change all.
        _metadatas = list(metadatas) if metadatas else [{} for _ in _texts]

        if len(_ids) != len(_texts):
            raise ValueError(
                f"Got {len(_ids)} ids for {len(_texts)} texts; lengths must match."
            )
        if len(_metadatas) != len(_texts):
            raise ValueError(
                f"Got {len(_metadatas)} metadatas for {len(_texts)} texts; "
                "lengths must match."
            )

        # Embed before touching any state so a failing embedding call cannot
        # leave the parallel lists with different lengths.
        new_embeddings = self._embedding_function.embed_documents(_texts)

        self._texts.extend(_texts)
        self._embeddings.extend(new_embeddings)
        self._metadatas.extend(_metadatas)
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        """Rebuild the numpy embedding matrix and refit the neighbors index.

        Raises:
            SKLearnVectorStoreException: If the store contains no embeddings.
        """
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def _document_at(self, idx: int) -> Document:
        """Build the Document stored at position ``idx``, adding its id to metadata."""
        return Document(
            page_content=self._texts[idx],
            metadata={"id": self._ids[idx], **self._metadatas[idx]},
        )

    def _similarity_index_search_with_score(
        self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[int, float]]:
        """Search k embeddings similar to the query embedding. Returns a list of
        (index, distance) tuples.

        At most as many results as there are stored embeddings are returned.

        Raises:
            SKLearnVectorStoreException: If the index has not been fitted yet.
        """
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        # ``kneighbors`` raises when asked for more neighbors than fitted
        # samples, which would make the default ``fetch_k`` of MMR unusable
        # on small stores; cap the request at the number of stored vectors.
        n_neighbors = min(k, len(self._embeddings_np))
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=n_neighbors
        )
        return list(zip(neigh_idxs[0], neigh_dists[0]))

    def similarity_search_with_score(
        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the documents closest to ``query`` with their distances.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Forwarded to the index search (currently unused there).

        Returns:
            List of (Document, distance) tuples as produced by the index.
        """
        query_embedding = self._embedding_function.embed_query(query)
        indices_dists = self._similarity_index_search_with_score(
            query_embedding, k=k, **kwargs
        )
        return [(self._document_at(idx), dist) for idx, dist in indices_dists]

    def similarity_search(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Document]:
        """Return the documents closest to ``query``, without distances.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Forwarded to ``similarity_search_with_score``.
        """
        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        return [doc for doc, _ in docs_scores]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return documents with a relevance score derived from distance.

        Each distance ``d`` is mapped to ``1 / exp(d)``, so a distance of 0
        yields a score of 1 and larger distances yield smaller scores.
        """
        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
        return [(doc, 1 / math.exp(dist)) for doc, dist in docs_dists]

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        indices_dists = self._similarity_index_search_with_score(
            embedding, k=fetch_k, **kwargs
        )
        indices = [idx for idx, _ in indices_dists]
        # Select the candidate rows of the embedding matrix in result order.
        result_embeddings = self._embeddings_np[indices]
        mmr_selected = maximal_marginal_relevance(
            self._np.array(embedding, dtype=self._np.float32),
            result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        # ``mmr_selected`` indexes into the candidate list, not the store.
        return [self._document_at(indices[i]) for i in mmr_selected]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: If the store has no embedding function.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )
        embedding = self._embedding_function.embed_query(query)
        # The keyword must be spelled ``lambda_mult``: a misspelled name is
        # absorbed by ``**kwargs`` and the caller's value silently ignored.
        return self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, lambda_mult=lambda_mult, **kwargs
        )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        """Create a store and add ``texts`` to it.

        Args:
            texts: Texts to add.
            embedding: Embedding function used for documents and queries.
            metadatas: Optional metadata dicts, one per text.
            ids: Optional ids, one per text.
            persist_path: Optional file path used for persistence.
            **kwargs: Extra keyword arguments passed to the constructor.

        Returns:
            The populated store, an instance of the class this is called on.
        """
        # Instantiate ``cls`` so subclasses get an instance of themselves.
        vs = cls(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs