# Source code for langchain_db2.db2vs

"""DB2 vector store wrapper."""

from __future__ import annotations

import functools
import hashlib
import json
import logging
import os
import re
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    TypeVar,
    cast,
)

import ibm_db_dbi  # type: ignore[import-untyped]
import numpy as np
from langchain_community.vectorstores.utils import (
    DistanceStrategy,
    maximal_marginal_relevance,
)
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore

from langchain_db2.utils import EmbeddingsSchema

if TYPE_CHECKING:
    from collections.abc import Iterable

    from ibm_db_dbi import Connection
    from langchain_core.embeddings import Embeddings

logger = logging.getLogger(__name__)
log_level = os.getenv("LOG_LEVEL", "ERROR").upper()
logging.basicConfig(
    level=getattr(logging, log_level),
    format="%(asctime)s - %(levelname)s - %(message)s",
)


# Define a type variable that can be any kind of function
T = TypeVar("T", bound=Callable[..., Any])


def _handle_exceptions(func: T) -> T:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        except RuntimeError as db_err:
            # Handle a known type of error (e.g., DB-related) specifically
            exception_msg = "DB-related error occurred."
            logger.exception(exception_msg)
            error_msg = f"Failed due to a DB issue: {db_err}"
            raise RuntimeError(error_msg) from db_err
        except ValueError as val_err:
            # Handle another known type of error specifically
            exception_msg = "Validation error."
            logger.exception(exception_msg)
            error_msg = f"Validation failed: {val_err}"
            raise ValueError(error_msg) from val_err
        except Exception as e:
            # Generic handler for all other exceptions
            exception_msg = f"An unexpected error occurred: {e}"
            logger.exception(exception_msg)
            error_msg = f"Unexpected error: {e}"
            raise RuntimeError(error_msg) from e

    return cast("T", wrapper)


def _table_exists(client: Connection, table_name: str) -> bool:
    cursor = client.cursor()
    try:
        cursor.execute(f"SELECT COUNT(*) FROM {table_name}")  # noqa: S608
    except Exception as ex:
        if "SQL0204N" in str(ex):
            return False
        raise
    finally:
        cursor.close()
    return True


def _get_distance_function(distance_strategy: DistanceStrategy) -> str:
    """Map a ``DistanceStrategy`` onto the Db2 VECTOR_DISTANCE function name.

    Raises:
        ValueError: If the strategy has no Db2 equivalent.
    """
    try:
        return {
            DistanceStrategy.EUCLIDEAN_DISTANCE: "EUCLIDEAN",
            DistanceStrategy.DOT_PRODUCT: "DOT",
            DistanceStrategy.COSINE: "COSINE",
        }[distance_strategy]
    except KeyError:
        error_msg = f"Unsupported distance strategy: {distance_strategy}"
        raise ValueError(error_msg) from None


@_handle_exceptions
def _create_table(
    client: Connection,
    table_name: str,
    embedding_dim: int,
    text_field: str = "text",
) -> None:
    """Create the vector-store table if it does not already exist.

    Args:
        client: The ibm_db_dbi connection object
        table_name: Name of the table to create
        embedding_dim: Dimension of the FLOAT32 vector column
        text_field: Column name for the raw text (CLOB)
    """
    if _table_exists(client, table_name):
        info_msg = f"Table {table_name} already exists..."
        logger.info(info_msg)
        return

    # id is CHAR(16) because ids are stored as 16-hex-char SHA-256 prefixes;
    # metadata is BLOB (BSON via SYSTOOLS.JSON2BSON elsewhere in this module).
    column_defs = (
        "id CHAR(16) PRIMARY KEY NOT NULL",
        f"{text_field} CLOB",
        "metadata BLOB",
        f"embedding vector({embedding_dim}, FLOAT32)",
    )
    ddl = f"CREATE TABLE {table_name} ({', '.join(column_defs)})"
    cursor = client.cursor()
    try:
        cursor.execute(ddl)
        cursor.execute("COMMIT")
        info_msg = f"Table {table_name} created successfully..."
        logger.info(info_msg)
    finally:
        cursor.close()

@_handle_exceptions
def drop_table(client: Connection, table_name: str) -> None:
    """Drop a table from the database.

    Args:
        client: The `ibm_db_dbi` connection object
        table_name: The name of the table to drop

    Raises:
        RuntimeError: If an error occurs while dropping the table
    """
    if not _table_exists(client, table_name):
        info_msg = f"Table {table_name} not found..."
        logger.info(info_msg)
        return

    cursor = client.cursor()
    try:
        cursor.execute(f"DROP TABLE {table_name}")
        cursor.execute("COMMIT")
        info_msg = f"Table {table_name} dropped successfully..."
        logger.info(info_msg)
    finally:
        cursor.close()
@_handle_exceptions
def clear_table(client: Connection, table_name: str) -> None:
    """Remove all records from the table using TRUNCATE.

    Args:
        client: The ibm_db_dbi connection object
        table_name: The name of the table to clear
    """
    if not _table_exists(client, table_name):
        info_msg = f"Table {table_name} not found…"
        logger.info(info_msg)
        return

    cursor = client.cursor()
    try:
        # Commit first — presumably because Db2 requires TRUNCATE ...
        # IMMEDIATE to be the first statement of its unit of work (TODO:
        # confirm against the Db2 docs); commit again to make it permanent.
        client.commit()
        cursor.execute(f"TRUNCATE TABLE {table_name} IMMEDIATE")
        client.commit()
        info_msg = f"Table {table_name} cleared successfully."
        logger.info(info_msg)
    except Exception:
        client.rollback()
        exception_msg = f"Failed to clear table {table_name}. Rolled back."
        logger.exception(exception_msg)
        raise
    finally:
        cursor.close()
class DB2VS(VectorStore):
    """`DB2VS` vector store backed by an IBM Db2 table with a vector column.

    Args:
        embedding_function: The embedding backend used to generate vectors
            for stored texts and queries
        table_name: DB2 table name
        client: Existing DB2 connection. Required if `connection_args` is
            not provided
        distance_strategy: Similarity metric used by Db2 `VECTOR_DISTANCE`
            when ranking results
        query: Probe text used once to infer embedding dimension
        params: Extra options
        connection_args: Connection parameters used when `client` is not
            supplied. Expected keys: `{"database": str, "host": str,
            "port": str, "username": str, "password": str, "security": bool}`
        text_field: Column name for the raw text (CLOB)

    Setup:
        Requires the `langchain_db2` python package and a connection to a
        db2 database with the vector store feature (v12.1.2+):

        ```bash
        pip install -U langchain-db2
        ```

    Instantiate with an existing `ibm_db_dbi.Connection`:

        ```python
        from langchain_db2 import DB2VS

        db2vs = DB2VS(
            embedding_function=embeddings,
            table_name=table_name,
            client=db_client,
        )
        ```

    Or let the store open its own connection via `connection_args`:

        ```python
        from langchain_db2 import DB2VS

        db2vs = DB2VS(
            embedding_function=embeddings,
            table_name=table_name,
            connection_args={
                "database": "<DATABASE>",
                "host": "<HOST>",
                "port": "<PORT>",
                "username": "<USERNAME>",
                "password": "<PASSWORD>",
                "security": False,
            },
        )
        ```
    """
[docs] def __init__( self, embedding_function: Callable[[str], list[float]] | Embeddings, table_name: str, client: Connection | None = None, distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE, query: str | None = "What is a Db2 database", params: dict[str, Any] | None = None, connection_args: dict[str, Any] | None = None, text_field: str = "text", ): """`DB2VS` vector store.""" if client is None: if connection_args is not None: database = connection_args.get("database") host = connection_args.get("host") port = connection_args.get("port") username = connection_args.get("username") password = connection_args.get("password") conn_str = ( f"DATABASE={database};hostname={host};port={port};" f"uid={username};pwd={password};" ) if "security" in connection_args: security = connection_args.get("security") conn_str += f"security={security};" self.client = ibm_db_dbi.connect(conn_str, "", "") else: error_msg = "No valid connection or connection_args is passed" raise ValueError(error_msg) else: """Initialize with ibm_db_dbi client.""" self.client = client try: """Initialize with necessary components.""" if not isinstance(embedding_function, EmbeddingsSchema): logger.warning( "`embedding_function` is expected to be an Embeddings " "object, support for passing in a function will soon " "be removed.", ) self.embedding_function = embedding_function self.query = query embedding_dim = self.get_embedding_dimension() self.table_name = table_name self.distance_strategy = distance_strategy self.params = params self._text_field = text_field _create_table( self.client, self.table_name, embedding_dim, text_field=self._text_field, ) except ibm_db_dbi.DatabaseError as db_err: exception_msg = f"Database error occurred while create table: {db_err}" logger.exception(exception_msg) error_msg = "Failed to create table due to a database error." 
raise RuntimeError(error_msg) from db_err except ValueError as val_err: exception_msg = f"Validation error: {val_err}" logger.exception(exception_msg) error_msg = "Failed to create table due to a validation error." raise RuntimeError(error_msg) from val_err except Exception as ex: exception_msg = "An unexpected error occurred while creating the table." logger.exception(exception_msg) error_msg = "Failed to create table due to an unexpected error." raise RuntimeError(error_msg) from ex
@property def embeddings(self) -> Embeddings | None: """A property that returns an Embeddings instance. Returns: Embeddings instance if embedding_function is an instance of Embeddings, otherwise returns None """ return ( self.embedding_function if isinstance(self.embedding_function, EmbeddingsSchema) else None )
[docs] def get_embedding_dimension(self) -> int: """Embed the single document by wrapping it in a list.""" embedded_document = self._embed_documents( [self.query if self.query is not None else ""], ) # Get the first (and only) embedding's dimension return len(embedded_document[0])
def _embed_documents(self, texts: list[str]) -> list[list[float]]: if isinstance(self.embedding_function, EmbeddingsSchema): return self.embedding_function.embed_documents(texts) if callable(self.embedding_function): return [self.embedding_function(text) for text in texts] error_msg = "The embedding_function is neither Embeddings nor callable." # type: ignore[unreachable] raise TypeError(error_msg) def _embed_query(self, text: str) -> list[float]: if isinstance(self.embedding_function, EmbeddingsSchema): return self.embedding_function.embed_query(text) return self.embedding_function(text)
[docs] @_handle_exceptions def add_texts( self, texts: Iterable[str], metadatas: list[dict[Any, Any]] | None = None, ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: """Add more texts to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore metadatas: Optional list of metadatas associated with the texts ids: Optional list of ids for the texts that are being added to the vector store kwargs: vectorstore specific parameters Returns: List of ids from adding the texts into the vectorstore """ texts = list(texts) if metadatas and len(metadatas) != len(texts): msg = ( f"metadatas must be the same length as texts. " f"Got {len(metadatas)} metadatas and {len(texts)} texts." ) raise ValueError(msg) if ids: if len(ids) != len(texts): msg = ( f"ids must be the same length as texts. " f"Got {len(ids)} ids and {len(texts)} texts." ) raise ValueError(msg) # If ids are provided, hash them to maintain consistency processed_ids = [ hashlib.sha256(_id.encode()).hexdigest()[:16].upper() for _id in ids ] elif metadatas: if all("id" in metadata for metadata in metadatas): # If no ids are provided but metadatas with ids are, generate # ids from metadatas processed_ids = [ hashlib.sha256(metadata["id"].encode()).hexdigest()[:16].upper() for metadata in metadatas ] else: # In the case partial metadata has id, generate new id if metadate # doesn't have it. 
processed_ids = [] for metadata in metadatas: if "id" in metadata: processed_ids.append( hashlib.sha256(metadata["id"].encode()) .hexdigest()[:16] .upper(), ) else: processed_ids.append( hashlib.sha256(str(uuid.uuid4()).encode()) .hexdigest()[:16] .upper(), ) else: # Generate new ids if none are provided generated_ids = [ str(uuid.uuid4()) for _ in texts ] # uuid4 is more standard for random UUIDs processed_ids = [ hashlib.sha256(_id.encode()).hexdigest()[:16].upper() for _id in generated_ids ] embeddings = self._embed_documents(texts) if not metadatas: metadatas = [{} for _ in texts] embedding_len = self.get_embedding_dimension() docs: list[tuple[Any, Any, Any, Any]] docs = [ (id_, f"{embedding}", json.dumps(metadata), text) for id_, embedding, metadata, text in zip( processed_ids, embeddings, metadatas, texts, ) ] sql_insert = ( f"INSERT INTO " # noqa: S608 f"{self.table_name} (id, embedding, metadata, {self._text_field}) " f"VALUES (?, VECTOR(?, {embedding_len}, FLOAT32), SYSTOOLS.JSON2BSON(?), ?)" ) cursor = self.client.cursor() try: cursor.executemany(sql_insert, docs) cursor.execute("COMMIT") finally: cursor.close() return processed_ids
[docs] def similarity_search_by_vector( self, embedding: list[float], k: int = 4, filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> list[Document]: """Return documents most similar to a query embedding. Args: embedding: Embedding to look up documents similar to k: Number of Documents to return filter: Filter by metadata kwargs: Additional keyword args Returns: Documents ordered from most to least similar """ docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( embedding=embedding, k=k, filter=filter, **kwargs, ) return [doc for doc, _ in docs_and_scores]
[docs] def similarity_search_with_score( self, query: str, k: int = 4, filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> list[tuple[Document, float]]: """Return the top-k documents most similar to a text query, with scores. Args: query: Natural-language query to embed and search with k: Number of results to return filter: Filter by metadata kwargs: Additional keyword args Returns: A list of (document, score) pairs ordered by similarity. The score is the vector **distance**; lower values indicate closer matches. """ if isinstance(self.embedding_function, EmbeddingsSchema): embedding = self.embedding_function.embed_query(query) return self.similarity_search_by_vector_with_relevance_scores( embedding=embedding, k=k, filter=filter, **kwargs, )
[docs] @_handle_exceptions def similarity_search_by_vector_with_relevance_scores( self, embedding: list[float], k: int = 4, filter: dict[str, Any] | None = None, # noqa: A002 ) -> list[tuple[Document, float]]: """Return top-k documents for a query embedding, with relevance scores. Args: embedding: Embedding to look up documents similar to k: Number of Documents to return filter: Filter by metadata Returns: A list of `(Document, distance)` pairs ordered from most to least similar (smallest distance first). """ docs_and_scores = [] embedding_len = self.get_embedding_dimension() query = f""" SELECT id, {self._text_field}, SYSTOOLS.BSON2JSON(metadata), vector_distance(embedding, VECTOR('{embedding}', {embedding_len}, FLOAT32), {_get_distance_function(self.distance_strategy)}) as distance FROM {self.table_name} ORDER BY distance FETCH FIRST {k} ROWS ONLY """ # noqa: S608 # TODO: # noqa: FIX002 TD003 TD002 # No APPROX in "FETCH APPROX FIRST" now. This will be added once # approximate nearest neighbors search in db2 is implemented. # Execute the query cursor = self.client.cursor() try: cursor.execute(query) results = cursor.fetchall() # Filter results if filter is provided for result in results: metadata = json.loads(result[2] if result[2] is not None else "{}") # Apply filtering based on the 'filter' dictionary if filter: if all(metadata.get(key) in value for key, value in filter.items()): doc = Document( page_content=(result[1] if result[1] is not None else ""), metadata=metadata, ) distance = result[3] docs_and_scores.append((doc, distance)) else: doc = Document( page_content=(result[1] if result[1] is not None else ""), metadata=metadata, ) distance = result[3] docs_and_scores.append((doc, distance)) finally: cursor.close() return docs_and_scores
[docs] @_handle_exceptions def similarity_search_by_vector_returning_embeddings( self, embedding: list[float], k: int, filter: dict[str, Any] | None = None, # noqa: A002 ) -> list[tuple[Document, float, np.ndarray]]: """Return top-k documents, their distances, and stored embeddings. Args: embedding: Embedding to look up documents similar to k: Number of Documents to return filter: Filter by metadata Returns: Tuples of `(document, distance, embedding_array)`, ordered from most to least similar (ascending distance) """ documents = [] embedding_len = self.get_embedding_dimension() query = f""" SELECT id, {self._text_field}, SYSTOOLS.BSON2JSON(metadata), vector_distance(embedding, VECTOR('{embedding}', {embedding_len}, FLOAT32), {_get_distance_function(self.distance_strategy)}) as distance, embedding FROM {self.table_name} ORDER BY distance FETCH FIRST {k} ROWS ONLY """ # noqa: S608 # TODO: # noqa: FIX002 TD003 TD002 # No APPROX in "FETCH APPROX FIRST" now. This will be added once # approximate nearest neighbors search in db2 is implemented. # Execute the query cursor = self.client.cursor() try: cursor.execute(query) results = cursor.fetchall() for result in results: page_content_str = result[1] if result[1] is not None else "" metadata = json.loads(result[2] if result[2] is not None else "{}") # Apply filter if provided and matches; otherwise, add all # documents if not filter or all( metadata.get(key) in value for key, value in filter.items() ): document = Document( page_content=page_content_str, metadata=metadata, ) distance = result[3] # Assuming result[4] is already in the correct format; # adjust if necessary current_embedding = ( np.array(json.loads(result[4]), dtype=np.float32) if result[4] else np.empty(0, dtype=np.float32) ) documents.append((document, distance, current_embedding)) finally: cursor.close() return documents
[docs] @_handle_exceptions def max_marginal_relevance_search_with_score_by_vector( self, embedding: list[float], *, k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: dict[str, Any] | None = None, # noqa: A002 ) -> list[tuple[Document, float]]: """Return docs and their similarity scores selected. Return docs and their similarity scores selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. Args: embedding: Embedding to look up documents similar to k: Number of Documents to return fetch_k: Number of Documents to fetch before filtering to pass to MMR algorithm lambda_mult: Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity filter: Filter by metadata Returns: List of Documents and similarity scores selected by maximal marginal relevance and score for each. """ # Fetch documents and their scores docs_scores_embeddings = self.similarity_search_by_vector_returning_embeddings( embedding, fetch_k, filter=filter, ) # Assuming documents_with_scores is a list of tuples (Document, score) # If you need to split documents and scores for processing (e.g., # for MMR calculation) documents, scores, embeddings = ( zip(*docs_scores_embeddings) if docs_scores_embeddings else ([], [], []) ) # Assume maximal_marginal_relevance method accepts embeddings and # scores, and returns indices of selected docs mmr_selected_indices = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), list(embeddings), k=k, lambda_mult=lambda_mult, ) # Filter documents based on MMR-selected indices and map scores return [(documents[i], scores[i]) for i in mmr_selected_indices]
[docs] @_handle_exceptions def max_marginal_relevance_search_by_vector( self, embedding: list[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. Args: embedding: Embedding to look up documents similar to k: Number of Documents to return fetch_k: Number of Documents to fetch to pass to MMR algorithm. lambda_mult: Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity filter: Filter by metadata kwargs: Additional keyword args Returns: List of Documents selected by maximal marginal relevance """ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter, ) return [doc for doc, _ in docs_and_scores]
[docs] @_handle_exceptions def delete(self, ids: list[str] | None = None, **kwargs: Any) -> None: """Delete by vector IDs. Args: ids: List of ids to delete kwargs: Additional keyword args """ if ids is None: error_msg = "No ids provided to delete." raise ValueError(error_msg) is_hashed = bool(ids) and all(re.fullmatch(r"[A-F0-9]{16}", _id) for _id in ids) if is_hashed: hashed_ids = ids # use as-is else: # Compute SHA-256 hashes of the raw ids and truncate them hashed_ids = [ hashlib.sha256(_id.encode("utf-8")).hexdigest()[:16].upper() for _id in ids ] # Constructing the SQL statement with individual placeholders placeholders = ", ".join("?" for _ in hashed_ids) ddl = f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})" # noqa: S608 cursor = self.client.cursor() try: cursor.execute(ddl, hashed_ids) cursor.execute("COMMIT") finally: cursor.close()
[docs] @classmethod @_handle_exceptions def from_texts( cls: type[DB2VS], texts: Iterable[str], embedding: Embeddings, metadatas: list[dict] | None = None, **kwargs: Any, ) -> DB2VS: """Return VectorStore initialized from texts and embeddings. Args: texts: Iterable of strings to add to the vectorstore embedding: Embedding to look up documents similar to metadatas: Optional list of metadatas associated with the texts kwargs: Additional keyword args Returns: A ready-to-use vector store with the provided texts loaded """ client = kwargs.get("client") if client is None: error_msg = "client parameter is required..." raise ValueError(error_msg) params = kwargs.get("params", {}) table_name = str(kwargs.get("table_name", "langchain")) distance_strategy = cast("DistanceStrategy", kwargs.get("distance_strategy")) if not isinstance(distance_strategy, DistanceStrategy): error_msg = ( # type: ignore[unreachable] f"Expected DistanceStrategy got {type(distance_strategy).__name__} " ) raise TypeError(error_msg) query = kwargs.get("query", "What is a Db2 database") drop_table(client, table_name) vss = cls( client=client, embedding_function=embedding, table_name=table_name, distance_strategy=distance_strategy, query=query, params=params, ) vss.add_texts(texts=list(texts), metadatas=metadatas) return vss
[docs] @_handle_exceptions def get_pks(self, expr: str | None = None) -> list[str]: """Get primary keys, optionally filtered by expr. Args: expr: SQL boolean expression to filter rows, e.g.: `id IN ('ABC123','DEF456')` or `title LIKE 'Abc%'`. If None, returns all rows. Returns: List of matching primary-key values. """ sql = f"SELECT id FROM {self.table_name}" # noqa: S608 if expr: sql += f" WHERE {expr}" cursor = self.client.cursor() try: cursor.execute(sql) rows = cursor.fetchall() finally: cursor.close() return [row[0] for row in rows]