Source code for langchain_weaviate.vectorstores

from __future__ import annotations

import datetime
import logging
from collections.abc import Generator
from contextlib import contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    overload,
)
from uuid import uuid4

import numpy as np
import weaviate  # type: ignore
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain_weaviate.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    import weaviate


# Module-level logging setup.
# NOTE: importing this module pins its logger to DEBUG and attaches a dedicated
# stream handler, so records are emitted through that handler independently of
# any logging configuration the host application applies later.
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%b-%d %I:%M %p",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)


def _default_schema(index_name: str) -> Dict:
    """Build the minimal class definition used when a collection is created.

    Args:
        index_name: Value stored under the ``"class"`` key.

    Returns:
        A new dict with the class name and a single ``text`` property whose
        data type is ``text``.
    """
    text_property = {"name": "text", "dataType": ["text"]}
    return {"class": index_name, "properties": [text_property]}


def _default_score_normalizer(val: float) -> float:
    """Map a raw score onto the interval [0, 1] with a logistic curve.

    Computes ``1 - 1 / (1 + exp(val))`` after clamping ``val``.

    Args:
        val: Raw score.

    Returns:
        The normalized score; ``0`` maps to ``0.5``.
    """
    # Clamp before exponentiating: 709 is the largest exponent that does not
    # overflow and -709 the smallest that does not underflow.
    bounded = np.clip(val, -709, 709)
    denominator = 1 + np.exp(bounded)
    return 1 - 1 / denominator


def _json_serializable(value: Any) -> Any:
    """Convert ``datetime.datetime`` values to ISO-8601 strings.

    Args:
        value: Any metadata value.

    Returns:
        ``value.isoformat()`` for ``datetime.datetime`` instances; every other
        value is returned unchanged.
    """
    return value.isoformat() if isinstance(value, datetime.datetime) else value


class WeaviateVectorStore(VectorStore):
    """`Weaviate` vector store.

    To use, you should have the ``weaviate-client`` python package installed.
    The constructor only accepts a ``weaviate.WeaviateClient`` instance.

    Example:
        .. code-block:: python

            import weaviate
            from langchain_weaviate.vectorstores import WeaviateVectorStore

            client = weaviate.connect_to_local()
            store = WeaviateVectorStore(client, index_name, text_key)
    """

    def __init__(
        self,
        client: weaviate.WeaviateClient,
        index_name: Optional[str],
        text_key: str,
        embedding: Optional[Embeddings] = None,
        attributes: Optional[List[str]] = None,
        relevance_score_fn: Optional[
            Callable[[float], float]
        ] = _default_score_normalizer,
        use_multi_tenancy: bool = False,
    ):
        """Initialize with Weaviate client.

        Creates the collection from the default schema when it does not exist
        yet, then caches the collection handle and its multi-tenancy setting.

        Args:
            client: Connected ``weaviate.WeaviateClient``.
            index_name: Collection name. When falsy, a name of the form
                ``LangChain_<hex uuid>`` is generated.
            text_key: Property under which the document text is stored.
            embedding: Embedding model used to vectorize texts and queries.
            attributes: Extra property names appended to the query attributes.
            relevance_score_fn: Function converting a raw score to a relevance
                score.
            use_multi_tenancy: Written into the schema used when the
                collection has to be created.

        Raises:
            ValueError: If ``client`` is not a ``weaviate.WeaviateClient``.
        """
        if not isinstance(client, weaviate.WeaviateClient):
            raise ValueError(
                "client should be an instance of"
                f" weaviate.WeaviateClient, got {type(client)}"
            )
        self._client = client
        self._index_name = index_name or f"LangChain_{uuid4().hex}"
        self._embedding = embedding
        self._text_key = text_key
        self._query_attrs = [self._text_key]
        self.relevance_score_fn = relevance_score_fn

        if attributes is not None:
            self._query_attrs.extend(attributes)

        schema = _default_schema(self._index_name)
        schema["MultiTenancyConfig"] = {"enabled": use_multi_tenancy}

        # check whether the index already exists
        if not client.collections.exists(self._index_name):
            client.collections.create_from_dict(schema)

        # store collection for convenience
        # this does not actually send a request to weaviate
        self._collection = client.collections.get(self._index_name)

        # store this setting so we don't have to send a request to weaviate
        # every time we want to do a CRUD operation
        self._multi_tenancy_enabled = self._collection.config.get(
            simple=False
        ).multi_tenancy_config.enabled

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Embedding model supplied at construction time, if any."""
        return self._embedding

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """Return the configured relevance function, or the default normalizer."""
        return (
            self.relevance_score_fn
            if self.relevance_score_fn
            else _default_score_normalizer
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        tenant: Optional[str] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Upload texts with metadata (properties) to Weaviate.

        Args:
            texts: Texts to store. Any iterable is accepted, including
                one-shot generators.
            metadatas: Optional per-text property dicts, matched to ``texts``
                by position. Values pass through ``_json_serializable``.
            tenant: Tenant to write to. A tenant that does not exist yet is
                created first.
            **kwargs: ``uuids`` or ``ids`` may carry one explicit ID per text;
                ``uuids`` takes precedence when both are given.

        Returns:
            The IDs of the submitted objects, in input order. Objects that the
            batch reports as failed are logged but their IDs are still
            returned.
        """
        from weaviate.util import get_valid_uuid  # type: ignore

        if tenant and not self._does_tenant_exist(tenant):
            logger.info(
                f"Tenant {tenant} does not exist in index {self._index_name}. "
                "Creating tenant."
            )
            tenant_objs = [weaviate.classes.tenants.Tenant(name=tenant)]
            self._collection.tenants.create(tenants=tenant_objs)

        # Materialize once: ``texts`` is consumed both by the embedding call
        # and by the upload loop, and a generator would be empty on the
        # second pass.
        texts = list(texts)

        ids = []
        embeddings: Optional[List[List[float]]] = None
        if self._embedding:
            embeddings = self._embedding.embed_documents(texts)

        with self._client.batch.dynamic() as batch:
            for i, text in enumerate(texts):
                data_properties = {self._text_key: text}
                if metadatas is not None:
                    for key, val in metadatas[i].items():
                        data_properties[key] = _json_serializable(val)

                # Allow for ids (consistent w/ other methods)
                # # Or uuids (backwards compatible w/ existing arg)
                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if "uuids" in kwargs:
                    _id = kwargs["uuids"][i]
                elif "ids" in kwargs:
                    _id = kwargs["ids"][i]
                else:
                    _id = get_valid_uuid(uuid4())

                batch.add_object(
                    collection=self._index_name,
                    properties=data_properties,
                    uuid=_id,
                    vector=embeddings[i] if embeddings else None,
                    tenant=tenant,
                )

                ids.append(_id)

        failed_objs = self._client.batch.failed_objects
        for obj in failed_objs:
            err_message = (
                f"Failed to add object: {obj.original_uuid}\nReason: {obj.message}"
            )

            logger.error(err_message)

        return ids

    @overload
    def _perform_search(
        self,
        query: Optional[str],
        k: int,
        return_score: Literal[False] = False,
        tenant: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]: ...

    @overload
    def _perform_search(
        self,
        query: Optional[str],
        k: int,
        return_score: Literal[True],
        tenant: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]: ...

    def _perform_search(
        self,
        query: Optional[str],
        k: int,
        return_score: bool = False,
        tenant: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[List[Document], List[Tuple[Document, float]]]:
        """
        Perform a similarity search.

        Parameters:
        query (str): The query string to search for.
        k (int): The number of results to return.
        return_score (bool, optional): Whether to return the score along with the
          document. Defaults to False.
        tenant (Optional[str], optional): The tenant name. Defaults to None.
        **kwargs: Additional parameters to pass to the search method. ``vector``
          (a precomputed query vector) and ``return_uuids`` (add the object
          UUID to each document's metadata) are consumed here; everything
          else is passed directly to the collection's ``query.hybrid`` call.
          ``"score"`` is always requested in ``return_metadata`` and the text
          key is always included in ``return_properties`` when that is given.

        Returns:
        List[Union[Document, Tuple[Document, float]]]: A list of documents that
          match the query. If return_score is True, each document is returned
          as a tuple with the document and its score.

        Raises:
        ValueError: If _embedding is None, if neither ``query`` nor ``vector``
          is provided, or if Weaviate reports a query error.
        """
        if self._embedding is None:
            raise ValueError("_embedding cannot be None for similarity_search")

        # Build new lists instead of appending in place so the list objects
        # handed in by the caller are never modified.
        if "return_metadata" not in kwargs:
            kwargs["return_metadata"] = ["score"]
        elif "score" not in kwargs["return_metadata"]:
            kwargs["return_metadata"] = [*kwargs["return_metadata"], "score"]

        if (
            "return_properties" in kwargs
            and self._text_key not in kwargs["return_properties"]
        ):
            kwargs["return_properties"] = [
                *kwargs["return_properties"],
                self._text_key,
            ]

        vector = kwargs.pop("vector", None)

        # workaround to handle test_max_marginal_relevance_search
        if vector is None:
            if query is None:
                # raise an error because weaviate will do a fetch object query
                # if both query and vector are None
                raise ValueError("Either query or vector must be provided.")
            else:
                vector = self._embedding.embed_query(query)

        return_uuids = kwargs.pop("return_uuids", False)

        with self._tenant_context(tenant) as collection:
            try:
                result = collection.query.hybrid(
                    query=query, vector=vector, limit=k, **kwargs
                )
            except weaviate.exceptions.WeaviateQueryException as e:
                # Chain the original exception so its traceback is preserved.
                raise ValueError(f"Error during query: {e}") from e

        docs_and_scores: List[Tuple[Document, float]] = []
        for obj in result.objects:
            text = obj.properties.pop(self._text_key)
            # Keep every populated metadata field except the score, which is
            # returned separately.
            filtered_metadata = {
                key: value
                for key, value in obj.metadata.__dict__.items()
                if value is not None and key != "score"
            }
            merged_props = {
                **obj.properties,
                **filtered_metadata,
                **({"vector": obj.vector["default"]} if obj.vector else {}),
                **({"uuid": str(obj.uuid)} if return_uuids else {}),
            }
            doc = Document(page_content=text, metadata=merged_props)
            score = obj.metadata.score
            docs_and_scores.append((doc, score))

        if return_score:
            return docs_and_scores
        else:
            return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
            **kwargs: Forwarded to the underlying search.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        results = self._perform_search(
            query=None,
            k=fetch_k,
            include_vector=True,
            vector=embedding,
            **kwargs,
        )

        embeddings = [result.metadata["vector"] for result in results]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
        )

        docs = []

        for idx in mmr_selected:
            text = results[idx].page_content
            # The vector was only fetched for the MMR computation; strip it
            # from the metadata that is handed back.
            results[idx].metadata.pop("vector")
            docs.append(Document(page_content=text, metadata=results[idx].metadata))

        return docs

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """
        Return list of documents most similar to the query
        text together with a score for each.

        The score is the ``score`` metadata field of each object returned by
        the hybrid query.
        NOTE(review): Weaviate's hybrid-search documentation should be
        consulted for how scores rank (presumably higher means a better
        match) - confirm.
        """

        results = self._perform_search(query, k, return_score=True, **kwargs)

        return results

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Optional[Embeddings],
        metadatas: Optional[List[dict]] = None,
        *,
        tenant: Optional[str] = None,
        client: Optional[weaviate.WeaviateClient] = None,
        index_name: Optional[str] = None,
        text_key: str = "text",
        relevance_score_fn: Optional[
            Callable[[float], float]
        ] = _default_score_normalizer,
        **kwargs: Any,
    ) -> WeaviateVectorStore:
        """Construct Weaviate wrapper from raw documents.

        This is a user-friendly interface that:
            1. Embeds documents.
            2. Creates a new index for the embeddings in the Weaviate instance.
            3. Adds the documents to the newly created Weaviate index.

        This is intended to be a quick way to get started.

        Args:
            texts: Texts to add to vector store.
            embedding: Text embedding model to use.
            client: ``weaviate.WeaviateClient`` to use. Required in practice:
                the constructor raises ``ValueError`` for any other value,
                including ``None``.
            metadatas: Metadata associated with each text. The keys of the
                first entry become the store's extra query attributes.
            tenant: The tenant name. Defaults to None. Multi-tenancy is
                requested for the collection whenever a tenant is given.
            index_name: Index name.
            text_key: Key to use for uploading/retrieving text to/from vectorstore.
            relevance_score_fn: Function for converting whatever distance function the
                vector store uses to a relevance score, which is a normalized similarity
                score (0 means dissimilar, 1 means similar).
            **kwargs: Additional named parameters to pass to ``add_texts()``
                (for example ``ids`` or ``uuids``).

        Example:
            .. code-block:: python

                from langchain_community.embeddings import OpenAIEmbeddings
                from langchain_weaviate.vectorstores import WeaviateVectorStore

                embeddings = OpenAIEmbeddings()
                store = WeaviateVectorStore.from_texts(
                    texts,
                    embeddings,
                    client=client
                )
        """

        attributes = list(metadatas[0].keys()) if metadatas else None

        weaviate_vector_store = cls(
            client,
            index_name,
            text_key,
            embedding=embedding,
            attributes=attributes,
            relevance_score_fn=relevance_score_fn,
            use_multi_tenancy=tenant is not None,
        )

        weaviate_vector_store.add_texts(texts, metadatas, tenant=tenant, **kwargs)

        return weaviate_vector_store

    def delete(
        self,
        ids: Optional[List[str]] = None,
        tenant: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            tenant: The tenant name. Defaults to None.

        Raises:
            ValueError: If ``ids`` is None, or if ``tenant`` does not match
                the collection's multi-tenancy setting.
        """

        if ids is None:
            raise ValueError("No ids provided to delete.")

        id_filter = weaviate.classes.query.Filter.by_id().contains_any(ids)

        with self._tenant_context(tenant) as collection:
            collection.data.delete_many(where=id_filter)

    def _does_tenant_exist(self, tenant: str) -> bool:
        """Check if tenant exists in Weaviate.

        Raises:
            AssertionError: If multi-tenancy is not enabled on the collection.
        """
        # Raised explicitly rather than via ``assert`` so the guard is not
        # stripped when Python runs with -O; type and message are unchanged.
        if not self._multi_tenancy_enabled:
            raise AssertionError(
                "Cannot check for tenant existence when multi-tenancy is not enabled"
            )
        tenants = self._collection.tenants.get()

        return tenant in tenants

    @contextmanager
    def _tenant_context(
        self, tenant: Optional[str] = None
    ) -> Generator[weaviate.collections.Collection, None, None]:
        """Context manager for handling tenants.

        Yields the collection handle bound to ``tenant``.

        Args:
            tenant: The tenant name. Defaults to None.

        Raises:
            ValueError: If a tenant is given while multi-tenancy is disabled,
                or no tenant is given while it is enabled.
        """

        if tenant is not None and not self._multi_tenancy_enabled:
            raise ValueError(
                "Cannot use tenant context when multi-tenancy is not enabled"
            )

        if tenant is None and self._multi_tenancy_enabled:
            raise ValueError("Must use tenant context when multi-tenancy is enabled")

        # No cleanup is required on exit, so a bare yield is sufficient.
        yield self._collection.with_tenant(tenant)