# Source code for langchain_astradb.utils.astradb

"""Utilities for AstraDB setup and management."""

from __future__ import annotations

import asyncio
import inspect
import json
import logging
import os
import warnings
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import TYPE_CHECKING, Any, Awaitable

import langchain_core
from astrapy import AsyncDatabase, DataAPIClient, Database
from astrapy.exceptions import DataAPIException

if TYPE_CHECKING:
    from astrapy.authentication import EmbeddingHeadersProvider, TokenProvider
    from astrapy.db import AstraDB, AsyncAstraDB
    from astrapy.info import CollectionDescriptor, CollectionVectorServiceOptions

# Names of the environment variables consulted when the corresponding
# constructor argument is not supplied.
TOKEN_ENV_VAR = "ASTRA_DB_APPLICATION_TOKEN"  # noqa: S105
API_ENDPOINT_ENV_VAR = "ASTRA_DB_API_ENDPOINT"
NAMESPACE_ENV_VAR = "ASTRA_DB_KEYSPACE"

# Default settings for API data operations (concurrency & similar):
# Chunk size for many-document insertions (None meaning defer to astrapy):
DEFAULT_DOCUMENT_CHUNK_SIZE = None
# Thread/coroutine count for bulk inserts:
MAX_CONCURRENT_DOCUMENT_INSERTIONS = 20
# Thread/coroutine count for one-doc-at-a-time overwrites:
MAX_CONCURRENT_DOCUMENT_REPLACEMENTS = 20
# Thread/coroutine count for one-doc-at-a-time deletes:
MAX_CONCURRENT_DOCUMENT_DELETIONS = 20

# Maximum number of documents fetched when surveying a collection:
SURVEY_NUMBER_OF_DOCUMENTS = 15

# Module-scoped logger: library code should not log through the root logger,
# so that applications can filter/configure this package's output by name.
# Records still propagate to the root logger's handlers by default.
logger = logging.getLogger(__name__)


class SetupMode(Enum):
    """Setup mode for the Astra DB collection."""

    # Collection is created synchronously inside the constructor.
    SYNC = 1
    # Collection creation is scheduled as an asyncio task by the constructor.
    ASYNC = 2
    # No collection setup is attempted by the constructor.
    OFF = 3
def _survey_collection(
    collection_name: str,
    *,
    token: str | TokenProvider | None = None,
    api_endpoint: str | None = None,
    environment: str | None = None,
    astra_db_client: AstraDB | None = None,
    async_astra_db_client: AsyncAstraDB | None = None,
    namespace: str | None = None,
) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]:
    """Return the collection descriptor (if found) and a sample of documents.

    Args:
        collection_name: name of the collection to look for.
        token: forwarded to `_AstraDBEnvironment`.
        api_endpoint: forwarded to `_AstraDBEnvironment`.
        environment: forwarded to `_AstraDBEnvironment`.
        astra_db_client: forwarded to `_AstraDBEnvironment` (deprecated).
        async_astra_db_client: forwarded to `_AstraDBEnvironment` (deprecated).
        namespace: forwarded to `_AstraDBEnvironment`.

    Returns:
        A pair `(descriptor, documents)`. If no collection with the given name
        is listed on the database, `(None, [])`; otherwise the matching
        descriptor and up to `SURVEY_NUMBER_OF_DOCUMENTS` documents from it.
    """
    _environment = _AstraDBEnvironment(
        token=token,
        api_endpoint=api_endpoint,
        environment=environment,
        astra_db_client=astra_db_client,
        async_astra_db_client=async_astra_db_client,
        namespace=namespace,
    )
    # first descriptor whose name matches, or None if the collection is absent
    descriptor = next(
        (
            coll_d
            for coll_d in _environment.database.list_collections()
            if coll_d.name == collection_name
        ),
        None,
    )
    if descriptor is None:
        return None, []
    # fetch some documents
    document_ite = _environment.database.get_collection(collection_name).find(
        filter={},
        projection={"*": True},
        limit=SURVEY_NUMBER_OF_DOCUMENTS,
    )
    return (descriptor, list(document_ite))


class _AstraDBEnvironment:
    """Normalized connection settings plus ready-to-use database handles.

    After construction the instance exposes `token`, `api_endpoint`,
    `namespace`, `environment`, `data_api_client`, `database` and
    `async_database`.
    """

    def __init__(
        self,
        token: str | TokenProvider | None = None,
        api_endpoint: str | None = None,
        environment: str | None = None,
        astra_db_client: AstraDB | None = None,
        async_astra_db_client: AsyncAstraDB | None = None,
        namespace: str | None = None,
    ) -> None:
        """Resolve the connection parameters and create the clients.

        Parameters are taken either from the (deprecated) ready-made clients or
        from the explicit arguments, the latter falling back to the environment
        variables named by `TOKEN_ENV_VAR`, `API_ENDPOINT_ENV_VAR` and
        `NAMESPACE_ENV_VAR`.

        Raises:
            ValueError: if ready-made clients are combined with `token`,
                `api_endpoint` or `environment`; if the sync and async clients
                disagree on token, API endpoint or namespace; or if no API
                endpoint can be determined.
        """
        self.token: str | TokenProvider | None
        self.api_endpoint: str | None
        self.namespace: str | None
        self.environment: str | None
        self.data_api_client: DataAPIClient
        self.database: Database
        self.async_database: AsyncDatabase

        if astra_db_client is not None or async_astra_db_client is not None:
            if token is not None or api_endpoint is not None or environment is not None:
                msg = (
                    "You cannot pass 'astra_db_client' or 'async_astra_db_client' "
                    "to AstraDBEnvironment if passing 'token', 'api_endpoint' or "
                    "'environment'."
                )
                raise ValueError(msg)
            # deprecation of the 'core classes' in constructor and conversion
            # to token/endpoint(-environment) based init, with checks
            warnings.warn(
                (
                    "Initializing Astra DB LangChain classes by passing "
                    "AstraDB/AsyncAstraDB ready clients is deprecated starting "
                    "with langchain-astradb==0.3.5. Please switch to passing "
                    "'token', 'api_endpoint' (and optionally 'environment') "
                    "instead."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
            # at least one of the two (core) clients is not None:
            _clients = [
                klient
                for klient in (astra_db_client, async_astra_db_client)
                if klient is not None
            ]
            # sets collapse identical values: exactly one element means agreement
            _tokens = list({klient.token for klient in _clients})
            _api_endpoints = list({klient.api_endpoint for klient in _clients})
            _namespaces = list({klient.namespace for klient in _clients})
            if len(_tokens) != 1:
                msg = (
                    "Conflicting tokens found in the sync and async AstraDB "
                    "constructor parameters. Please check the tokens and "
                    "ensure they match."
                )
                raise ValueError(msg)
            if len(_api_endpoints) != 1:
                msg = (
                    "Conflicting API endpoints found in the sync and async "
                    "AstraDB constructor parameters. Please check the API "
                    "endpoints and ensure they match."
                )
                raise ValueError(msg)
            if len(_namespaces) != 1:
                msg = (
                    "Conflicting namespaces found in the sync and async "
                    "AstraDB constructor parameters. Please check the "
                    "namespaces and ensure they match."
                )
                raise ValueError(msg)
            # all good: these are 1-element lists here
            self.token = _tokens[0]
            self.api_endpoint = _api_endpoints[0]
            self.namespace = _namespaces[0]
        else:
            _token: str | TokenProvider | None
            # secrets-based initialization
            if token is None:
                logger.info(
                    "Attempting to fetch token from environment variable '%s'",
                    TOKEN_ENV_VAR,
                )
                _token = os.environ.get(TOKEN_ENV_VAR)
            else:
                _token = token
            if api_endpoint is None:
                logger.info(
                    "Attempting to fetch API endpoint from environment variable '%s'",
                    API_ENDPOINT_ENV_VAR,
                )
                _api_endpoint = os.environ.get(API_ENDPOINT_ENV_VAR)
            else:
                _api_endpoint = api_endpoint
            if namespace is None:
                _namespace = os.environ.get(NAMESPACE_ENV_VAR)
            else:
                _namespace = namespace

            self.token = _token
            self.api_endpoint = _api_endpoint
            self.namespace = _namespace

        self.environment = environment

        # init parameters are normalized to self.{token, api_endpoint, namespace}.
        # Proceed. Namespace and token can be None (resp. on Astra DB and non-Astra)
        if self.api_endpoint is None:
            msg = (
                "API endpoint for Data API not provided. "
                "Either pass it explicitly to the object constructor "
                f"or set the {API_ENDPOINT_ENV_VAR} environment variable."
            )
            raise ValueError(msg)

        # create the clients
        caller_name = "langchain"
        caller_version = getattr(langchain_core, "__version__", None)
        self.data_api_client = DataAPIClient(
            environment=self.environment,
            caller_name=caller_name,
            caller_version=caller_version,
        )
        self.database = self.data_api_client.get_database(
            api_endpoint=self.api_endpoint,
            token=self.token,
            namespace=self.namespace,
        )
        self.async_database = self.database.to_async()


class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    """An `_AstraDBEnvironment` bound to one collection, with optional setup.

    Besides the attributes of the parent class, the instance exposes
    `collection_name`, `collection`, `async_collection` and
    `async_setup_db_task` (the pending setup task in ASYNC mode, else None).
    """

    def __init__(
        self,
        collection_name: str,
        *,
        token: str | TokenProvider | None = None,
        api_endpoint: str | None = None,
        environment: str | None = None,
        astra_db_client: AstraDB | None = None,
        async_astra_db_client: AsyncAstraDB | None = None,
        namespace: str | None = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: int | Awaitable[int] | None = None,
        metric: str | None = None,
        requested_indexing_policy: dict[str, Any] | None = None,
        default_indexing_policy: dict[str, Any] | None = None,
        collection_vector_service_options: CollectionVectorServiceOptions
        | None = None,
        collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None,
    ) -> None:
        """Connect to the database and, depending on `setup_mode`, set it up.

        With `SetupMode.SYNC` the collection is (optionally dropped and then)
        created right away; with `SetupMode.ASYNC` the same work is scheduled
        through `asyncio.create_task`; any other mode performs no setup.

        Raises:
            ValueError: in SYNC mode, if `embedding_dimension` is an awaitable;
                also anything raised by the parent constructor or by
                `_validate_indexing_policy`.
            DataAPIException: in SYNC mode, if collection creation fails for
                reasons not attributable to the indexing policy.
        """
        super().__init__(
            token=token,
            api_endpoint=api_endpoint,
            environment=environment,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
        )
        self.collection_name = collection_name
        self.collection = self.database.get_collection(
            name=self.collection_name,
            embedding_api_key=collection_embedding_api_key,
        )
        self.async_collection = self.collection.to_async()

        self.async_setup_db_task: Task | None = None
        if setup_mode == SetupMode.ASYNC:
            self.async_setup_db_task = asyncio.create_task(
                self._asetup_db(
                    pre_delete_collection=pre_delete_collection,
                    embedding_dimension=embedding_dimension,
                    metric=metric,
                    default_indexing_policy=default_indexing_policy,
                    requested_indexing_policy=requested_indexing_policy,
                    collection_vector_service_options=collection_vector_service_options,
                )
            )
        elif setup_mode == SetupMode.SYNC:
            # Validate arguments *before* any destructive action, so that an
            # invalid call cannot drop an existing collection and then fail.
            if inspect.isawaitable(embedding_dimension):
                msg = (
                    "Cannot use an awaitable embedding_dimension with async_setup "
                    "set to False"
                )
                raise ValueError(msg)
            if pre_delete_collection:
                self.database.drop_collection(collection_name)
            try:
                self.database.create_collection(
                    name=collection_name,
                    dimension=embedding_dimension,
                    metric=metric,
                    indexing=requested_indexing_policy,
                    # Used for enabling $vectorize on the collection
                    service=collection_vector_service_options,
                    check_exists=False,
                )
            except DataAPIException:
                # possibly the collection is preexisting and may have legacy,
                # or custom, indexing settings: verify
                collection_descriptors = list(self.database.list_collections())
                if not self._validate_indexing_policy(
                    collection_descriptors=collection_descriptors,
                    collection_name=self.collection_name,
                    requested_indexing_policy=requested_indexing_policy,
                    default_indexing_policy=default_indexing_policy,
                ):
                    # other reasons for the exception
                    raise

    async def _asetup_db(
        self,
        *,
        pre_delete_collection: bool,
        embedding_dimension: int | Awaitable[int] | None,
        metric: str | None,
        requested_indexing_policy: dict[str, Any] | None,
        default_indexing_policy: dict[str, Any] | None,
        collection_vector_service_options: CollectionVectorServiceOptions | None,
    ) -> None:
        """Asynchronous counterpart of the SYNC setup done in the constructor.

        Optionally drops the collection, resolves `embedding_dimension`
        (awaiting it if it is an awaitable) and creates the collection. A
        `DataAPIException` on creation is re-raised unless
        `_validate_indexing_policy` reports it as handled.
        """
        if pre_delete_collection:
            await self.async_database.drop_collection(self.collection_name)
        if inspect.isawaitable(embedding_dimension):
            dimension = await embedding_dimension
        else:
            dimension = embedding_dimension

        try:
            await self.async_database.create_collection(
                name=self.collection_name,
                dimension=dimension,
                metric=metric,
                indexing=requested_indexing_policy,
                # Used for enabling $vectorize on the collection
                service=collection_vector_service_options,
                check_exists=False,
            )
        except DataAPIException:
            # possibly the collection is preexisting and may have legacy,
            # or custom, indexing settings: verify
            collection_descriptors = [
                coll_desc async for coll_desc in self.async_database.list_collections()
            ]
            if not self._validate_indexing_policy(
                collection_descriptors=collection_descriptors,
                collection_name=self.collection_name,
                requested_indexing_policy=requested_indexing_policy,
                default_indexing_policy=default_indexing_policy,
            ):
                # other reasons for the exception
                raise

    @staticmethod
    def _validate_indexing_policy(
        collection_descriptors: list[CollectionDescriptor],
        collection_name: str,
        requested_indexing_policy: dict[str, Any] | None,
        default_indexing_policy: dict[str, Any] | None,
    ) -> bool:
        """Validate indexing policy.

        This is a validation helper, to be called when the collection-creation
        call has failed.

        Args:
            collection_descriptors: collection descriptors for the database.
            collection_name: the name of the collection whose attempted
                creation failed
            requested_indexing_policy: the 'indexing' part of the collection
                options, e.g. `{"deny": ["field1", "field2"]}`.
                Leave to its default of None if no options required.
            default_indexing_policy: an optional 'default value' for the
                above, used to issue just a gentle warning in the special
                case that no policy is detected on a preexisting collection
                on DB and the default is requested. This is to enable
                a warning-only transition to new code using indexing without
                disrupting usage of a legacy collection, i.e. one created
                before adopting the usage of indexing policies altogether.
                You cannot pass this one without requested_indexing_policy.

        This function may raise an error (indexing mismatches), issue a warning
        (about legacy collections), or do nothing.

        Returns:
            True if the exception was handled here as part of the indexing
            management; False if the exception is unrelated to indexing and
            the caller has to reraise it.

        Raises:
            ValueError: if a default policy is given without a requested one,
                or if the preexisting collection's indexing settings are
                incompatible with the requested policy.
        """
        if requested_indexing_policy is None and default_indexing_policy is not None:
            msg = (
                "Cannot specify a default indexing policy "
                "when no indexing policy is requested for this collection "
                "(requested_indexing_policy is None, "
                "default_indexing_policy is not None)."
            )
            raise ValueError(msg)

        preexisting = [
            collection
            for collection in collection_descriptors
            if collection.name == collection_name
        ]
        if not preexisting:
            # foreign-origin for the original exception
            return False

        pre_collection = preexisting[0]
        # if it has no "indexing", it is a legacy collection
        pre_col_options = pre_collection.options
        if not pre_col_options.indexing:
            # legacy collection on DB
            if requested_indexing_policy != default_indexing_policy:
                msg = (
                    f"Astra DB collection '{collection_name}' is "
                    "detected as having indexing turned on for all "
                    "fields (either created manually or by older "
                    "versions of this plugin). This is incompatible with "
                    "the requested indexing policy for this object. "
                    "Consider indexing anew on a fresh "
                    "collection with the requested indexing "
                    "policy, or alternatively leave the indexing "
                    "settings for this object to their defaults "
                    "to keep using this collection."
                )
                raise ValueError(msg)
            warnings.warn(
                (
                    f"Astra DB collection '{collection_name}' is "
                    "detected as having indexing turned on for all "
                    "fields (either created manually or by older "
                    "versions of this plugin). This implies stricter "
                    "limitations on the amount of text each string in a "
                    "document can store. Consider indexing anew on a "
                    "fresh collection to be able to store longer texts. "
                    "See https://github.com/langchain-ai/langchain-"
                    "datastax/blob/main/libs/astradb/README.md#"
                    "warnings-about-indexing for more details."
                ),
                UserWarning,
                stacklevel=2,
            )
            # the original exception, related to indexing, was handled here
            return True

        if pre_col_options.indexing != requested_indexing_policy:
            # collection on DB has indexing settings, but different
            options_json = json.dumps(pre_col_options.indexing)
            default_desc = (
                " (default setting)"
                if pre_col_options.indexing == default_indexing_policy
                else ""
            )
            msg = (
                f"Astra DB collection '{collection_name}' is "
                "detected as having the following indexing policy: "
                f"{options_json}{default_desc}. This is incompatible "
                "with the requested indexing policy for this object. "
                "Consider indexing anew on a fresh "
                "collection with the requested indexing "
                "policy, or alternatively align the requested "
                "indexing settings to the collection to keep using it."
            )
            raise ValueError(msg)

        # the discrepancies have to do with options other than indexing
        return False

    def ensure_db_setup(self) -> None:
        """Check, from synchronous code, that any async setup has completed.

        Does nothing when no async setup task was scheduled. Otherwise the
        task's result is retrieved, which surfaces any exception the setup
        raised.

        Raises:
            ValueError: if the async setup task has not finished yet.
        """
        if self.async_setup_db_task:
            try:
                self.async_setup_db_task.result()
            except InvalidStateError as e:
                msg = (
                    "Asynchronous setup of the DB not finished. "
                    "NB: Astra DB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                )
                raise ValueError(msg) from e

    async def aensure_db_setup(self) -> None:
        """Await the async setup task, if one was scheduled."""
        if self.async_setup_db_task:
            await self.async_setup_db_task