Source code for langchain_community.vectorstores.sklearn

"""Wrapper around scikit-learn NearestNeighbors implementation.

The vector store can be persisted in json, bson or parquet format.
"""

import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
from uuid import uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

# Default for the ``k`` parameter of the search methods in this module.
DEFAULT_K = 4  # Number of Documents to return.
# Default for the ``fetch_k`` parameter of the MMR search: the size of the
# candidate pool retrieved before the MMR re-ranking step.
DEFAULT_FETCH_K = 20  # Number of Documents to initially fetch during MMR search.


class BaseSerializer(ABC):
    """Abstract interface for persisting data at a fixed path.

    Concrete subclasses choose the on-disk format by implementing
    ``extension``, ``save`` and ``load``.
    """

    def __init__(self, persist_path: str) -> None:
        """Remember the location used by ``save`` and ``load``.

        Args:
            persist_path: Path the serializer writes to and reads from.
        """
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """Return the file extension suggested for this format (no dot)."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Write ``data`` to ``self.persist_path``."""

    @abstractmethod
    def load(self) -> Any:
        """Read and return the data stored at ``self.persist_path``."""
class JsonSerializer(BaseSerializer):
    """Serialize data in JSON using the ``json`` package from the standard library."""

    @classmethod
    def extension(cls) -> str:
        """Return ``"json"``, the suggested file extension (without dot)."""
        return "json"

    def save(self, data: Any) -> None:
        """Encode ``data`` as JSON and write it to ``self.persist_path``.

        Args:
            data: Any object accepted by ``json.dumps``.

        Raises:
            TypeError / ValueError: Propagated from ``json.dumps`` when ``data``
                cannot be encoded. The existing file is left untouched then.
        """
        # Encode before opening the file: opening in "w" mode truncates it, so
        # a failure during encoding would otherwise wipe the previous store.
        payload = json.dumps(data)
        # Fix the encoding explicitly rather than relying on the platform
        # default, so files are portable between machines.
        with open(self.persist_path, "w", encoding="utf-8") as fp:
            fp.write(payload)

    def load(self) -> Any:
        """Read ``self.persist_path`` and return the decoded JSON value."""
        with open(self.persist_path, "r", encoding="utf-8") as fp:
            return json.load(fp)
class BsonSerializer(BaseSerializer):
    """Serialize data in Binary JSON using the `bson` python package."""

    def __init__(self, persist_path: str) -> None:
        """Store the path and import the ``bson`` module via ``guard_import``.

        Args:
            persist_path: Path the serializer writes to and reads from.
        """
        super().__init__(persist_path)
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        """Return ``"bson"``, the suggested file extension (without dot)."""
        return "bson"

    def save(self, data: Any) -> None:
        """Encode ``data`` with ``bson.dumps`` and write it to ``self.persist_path``.

        Args:
            data: Object to encode; it must be acceptable to ``bson.dumps``.
        """
        # Encode before opening the file: "wb" truncates on open, so an
        # encoding failure would otherwise destroy the previous store.
        payload = self.bson.dumps(data)
        with open(self.persist_path, "wb") as fp:
            fp.write(payload)

    def load(self) -> Any:
        """Read ``self.persist_path`` and return the value decoded by ``bson.loads``."""
        with open(self.persist_path, "rb") as fp:
            return self.bson.loads(fp.read())
class ParquetSerializer(BaseSerializer):
    """Serialize data in `Apache Parquet` format using the `pyarrow` package."""

    def __init__(self, persist_path: str) -> None:
        """Store the path and import pandas / pyarrow via ``guard_import``.

        Args:
            persist_path: Path the serializer writes to and reads from.
        """
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        """Return ``"parquet"``, the suggested file extension (without dot)."""
        return "parquet"

    def save(self, data: Any) -> None:
        """Write ``data`` as a Parquet table to ``self.persist_path``.

        When a file already exists at the target path it is first moved to
        ``<persist_path>-backup``. If writing fails the backup is moved back
        and the error is re-raised; on success the backup is deleted.

        Args:
            data: Anything accepted by the ``pandas.DataFrame`` constructor.
        """
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)

        if not os.path.exists(self.persist_path):
            self.pq.write_table(table, self.persist_path)
            return

        backup_path = str(self.persist_path) + "-backup"
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises FileExistsError on Windows when a stale backup is
        # still lying around.
        os.replace(self.persist_path, backup_path)
        try:
            self.pq.write_table(table, self.persist_path)
        except Exception:
            # A failed write may have left a partial file at persist_path;
            # os.replace lets the backup overwrite it during the restore.
            os.replace(backup_path, self.persist_path)
            raise
        else:
            os.remove(backup_path)

    def load(self) -> Any:
        """Read the Parquet file and return a dict mapping column name to a list of values."""
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
# Lookup from serializer name to implementing class. Each key is the
# extension the class itself reports: "json", "bson" and "parquet".
SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    serializer_cls.extension(): serializer_cls
    for serializer_cls in (JsonSerializer, BsonSerializer, ParquetSerializer)
}
class SKLearnVectorStoreException(RuntimeError):
    """Error type raised by SKLearnVectorStore for store-specific failures."""
class SKLearnVectorStore(VectorStore):
    """Simple in-memory vector store based on the `scikit-learn` library
    `NearestNeighbors`.

    Texts, metadata, ids and embeddings are kept in parallel Python lists.
    Every mutation re-fits a ``NearestNeighbors`` index on the full set of
    embeddings. The store can optionally be persisted through one of the
    serializers registered in ``SERIALIZER_MAP``.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        """Create the store and load existing data when a persisted file exists.

        Args:
            embedding: Embedding model used for documents and queries.
            persist_path: Optional file path used by ``persist``. If a file
                already exists there, its content is loaded immediately.
            serializer: Key into ``SERIALIZER_MAP`` selecting the file format.
            metric: Distance metric forwarded to ``NearestNeighbors``.
            **kwargs: Extra keyword arguments forwarded to ``NearestNeighbors``.
        """
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties (parallel lists: position i describes document i)
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    @property
    def embeddings(self) -> Embeddings:
        """The embedding model passed at construction."""
        return self._embedding_function

    def persist(self) -> None:
        """Save ids, texts, metadatas and embeddings through the serializer.

        Raises:
            SKLearnVectorStoreException: If the store was created without a
                ``persist_path``.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        """Replace the in-memory data with the serializer's content and re-fit.

        Raises:
            SKLearnVectorStoreException: If the store was created without a
                ``persist_path``, or if the loaded data has no embeddings.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the "
                "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts``, append them to the store and re-fit the index.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata dicts, one per text. When omitted (or
                empty) every text gets its own fresh empty dict.
            ids: Optional ids, one per text. When omitted (or empty) random
                UUID4 strings are generated.

        Returns:
            The ids associated with the added texts.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length that
                differs from the number of texts.
            SKLearnVectorStoreException: If the store is still empty afterwards.
        """
        _texts = list(texts)
        _ids = list(ids) if ids else [str(uuid4()) for _ in _texts]
        # One distinct dict per document: ``[{}] * n`` would alias a single
        # dict, so mutating one document's metadata would change all of them.
        _metadatas = list(metadatas) if metadatas else [{} for _ in _texts]

        if len(_ids) != len(_texts):
            raise ValueError(
                f"Got {len(_ids)} ids for {len(_texts)} texts; lengths must match."
            )
        if len(_metadatas) != len(_texts):
            raise ValueError(
                f"Got {len(_metadatas)} metadatas for {len(_texts)} texts; "
                "lengths must match."
            )

        # Embed before touching any state so a failing embedding call cannot
        # leave the parallel lists with different lengths.
        new_embeddings = self._embedding_function.embed_documents(_texts)

        self._texts.extend(_texts)
        self._embeddings.extend(new_embeddings)
        self._metadatas.extend(_metadatas)
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        """Rebuild the numpy cache and re-fit the index on all embeddings.

        Raises:
            SKLearnVectorStoreException: If there are no embeddings to fit.
        """
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def _similarity_index_search_with_score(
        self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[int, float]]:
        """Search k embeddings similar to the query embedding.

        Args:
            query_embedding: Embedding to compare against the stored ones.
            k: Maximum number of neighbours to return. It is capped at the
                number of stored embeddings.

        Returns:
            A list of ``(index, distance)`` tuples, where ``index`` is a
            position in the store's parallel lists.

        Raises:
            SKLearnVectorStoreException: If the index has not been fitted yet.
        """
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        # kneighbors raises when asked for more neighbours than fitted
        # samples, which made MMR search (fetch_k=20 by default) fail on
        # small stores.
        n_neighbors = min(k, len(self._embeddings_np))
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=n_neighbors
        )
        # Convert numpy scalars to the plain int/float promised by the
        # return annotation.
        return [
            (int(idx), float(dist))
            for idx, dist in zip(neigh_idxs[0], neigh_dists[0])
        ]

    def similarity_search_with_score(
        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the documents nearest to ``query`` together with their distance.

        Args:
            query: Text to embed and search for.
            k: Maximum number of documents to return.

        Returns:
            ``(Document, distance)`` tuples. Each document's metadata is the
            stored metadata plus an ``"id"`` entry; a stored ``"id"`` key
            takes precedence over the store id.
        """
        query_embedding = self._embedding_function.embed_query(query)
        indices_dists = self._similarity_index_search_with_score(
            query_embedding, k=k, **kwargs
        )
        return [
            (
                Document(
                    page_content=self._texts[idx],
                    metadata={"id": self._ids[idx], **self._metadatas[idx]},
                ),
                dist,
            )
            for idx, dist in indices_dists
        ]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return documents with a relevance score computed as ``exp(-distance)``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of documents to return.

        Returns:
            ``(Document, score)`` tuples; a distance of 0 maps to a score of 1
            and larger distances map to scores approaching 0.
        """
        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
        # exp(-d) equals 1 / exp(d) but underflows to 0.0 for large distances
        # instead of raising OverflowError. The comprehension also copes with
        # an empty result, which ``zip(*...)`` unpacking does not.
        return [(doc, math.exp(-dist)) for doc, dist in docs_dists]

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        indices_dists = self._similarity_index_search_with_score(
            embedding, k=fetch_k, **kwargs
        )
        indices = [idx for idx, _ in indices_dists]
        # Candidate embeddings, in the same order as ``indices``.
        result_embeddings = self._embeddings_np[indices]
        mmr_selected = maximal_marginal_relevance(
            self._np.array(embedding, dtype=self._np.float32),
            result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        # MMR returns positions within the candidate list; map them back to
        # positions in the store.
        mmr_indices = [indices[i] for i in mmr_selected]
        return [
            Document(
                page_content=self._texts[idx],
                metadata={"id": self._ids[idx], **self._metadatas[idx]},
            )
            for idx in mmr_indices
        ]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        """Build a store and add ``texts`` to it.

        Args:
            texts: Texts to add.
            embedding: Embedding model used for documents and queries.
            metadatas: Optional metadata dicts, one per text.
            ids: Optional ids, one per text.
            persist_path: Optional persistence path forwarded to the constructor.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            The populated store, an instance of the class this is called on.
        """
        # Instantiate ``cls`` so that subclasses get an instance of themselves.
        vs = cls(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs