# Source code for langchain_astradb.document_loaders

"""Loader for loading documents from DataStax Astra DB."""

from __future__ import annotations

import json
import logging
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Iterator,
)

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from typing_extensions import override

from langchain_astradb.utils.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)

if TYPE_CHECKING:
    from astrapy.authentication import TokenProvider
    from astrapy.db import AstraDB, AsyncAstraDB

logger = logging.getLogger(__name__)

# Sentinel distinguishing "argument not passed" from an explicit None
# (used for parameters whose absence triggers a default or a warning).
_NOT_SET = object()


class AstraDBLoader(BaseLoader):
    """Document loader reading documents from a DataStax Astra DB collection."""

    def __init__(
        self,
        collection_name: str,
        *,
        token: str | TokenProvider | None = None,
        api_endpoint: str | None = None,
        environment: str | None = None,
        astra_db_client: AstraDB | None = None,
        async_astra_db_client: AsyncAstraDB | None = None,
        namespace: str | None = None,
        filter_criteria: dict[str, Any] | None = None,
        projection: dict[str, Any] | None = _NOT_SET,  # type: ignore[assignment]
        find_options: dict[str, Any] | None = None,
        limit: int | None = None,
        nb_prefetched: int = _NOT_SET,  # type: ignore[assignment]
        page_content_mapper: Callable[[dict], str] = json.dumps,
        metadata_mapper: Callable[[dict], dict[str, Any]] | None = None,
    ) -> None:
        """Load DataStax Astra DB documents.

        Args:
            collection_name: name of the Astra DB collection to use.
            token: API token for Astra DB usage, either in the form of a string
                or a subclass of `astrapy.authentication.TokenProvider`.
                If not provided, the environment variable
                ASTRA_DB_APPLICATION_TOKEN is inspected.
            api_endpoint: full URL to the API endpoint, such as
                `https://<DB-ID>-us-east1.apps.astra.datastax.com`. If not
                provided, the environment variable ASTRA_DB_API_ENDPOINT
                is inspected.
            environment: a string specifying the environment of the target
                Data API. If omitted, defaults to "prod" (Astra DB production).
                Other values are in `astrapy.constants.Environment` enum class.
            astra_db_client:
                *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally
                'environment'.*
                you can pass an already-created 'astrapy.db.AstraDB' instance
                (alternatively to 'token', 'api_endpoint' and 'environment').
            async_astra_db_client:
                *DEPRECATED starting from version 0.3.5.*
                *Please use 'token', 'api_endpoint' and optionally
                'environment'.*
                you can pass an already-created 'astrapy.db.AsyncAstraDB'
                instance (alternatively to 'token', 'api_endpoint' and
                'environment').
            namespace: namespace (aka keyspace) where the collection resides.
                If not provided, the environment variable ASTRA_DB_KEYSPACE
                is inspected. Defaults to the database's "default namespace".
            filter_criteria: Criteria to filter documents.
            projection: Specifies the fields to return. If not provided, reads
                fall back to the Data API default projection.
            find_options: Additional options for the query.
                *DEPRECATED starting from version 0.3.5.*
                *For limiting, please use `limit`. Other options are ignored.*
            limit: a maximum number of documents to return in the read query.
            nb_prefetched: Max number of documents to pre-fetch.
                *IGNORED starting from v. 0.3.5: astrapy v1.0+ does not
                support it.*
            page_content_mapper: Function applied to collection documents to
                create the `page_content` of the LangChain Document. Defaults
                to `json.dumps`.
            metadata_mapper: Function applied to collection documents to create
                the `metadata` of the LangChain Document. Defaults to returning
                the namespace, API endpoint and collection name.
        """
        # SetupMode.OFF: the loader only reads; it never creates collections.
        astra_db_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            environment=environment,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=SetupMode.OFF,
        )
        self.astra_db_env = astra_db_env
        self.filter = filter_criteria
        # If the caller did not pass a projection at all, request all fields;
        # an explicit None falls back to the Data API default projection.
        self._projection: dict[str, Any] | None = (
            projection if projection is not _NOT_SET else {"*": True}
        )
        # warning if 'prefetched' passed
        if nb_prefetched is not _NOT_SET:
            warnings.warn(
                (
                    "Parameter 'nb_prefetched' is not supported by the Data API "
                    "client and will be ignored in reading document."
                ),
                UserWarning,
                stacklevel=2,
            )
        # normalizing limit and options and deprecations
        _find_options = find_options.copy() if find_options else {}
        if "limit" in _find_options:
            if limit is not None:
                msg = (
                    "Duplicate 'limit' directive supplied. Please remove it "
                    "from the 'find_options' map parameter."
                )
                raise ValueError(msg)
            warnings.warn(
                (
                    "Passing 'limit' as part of the 'find_options' "
                    "dictionary is deprecated starting from version 0.3.5. "
                    "Please switch to passing 'limit=<number>' "
                    "directly in the constructor."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
        self.limit = _find_options.pop("limit", limit)
        if _find_options:
            warnings.warn(
                (
                    "Unknown keys passed in the 'find_options' dictionary. "
                    "This parameter is deprecated starting from version 0.3.5."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
        self.nb_prefetched = nb_prefetched
        self.page_content_mapper = page_content_mapper
        # Default metadata ignores the document itself and records only
        # where it came from (namespace, endpoint, collection).
        self.metadata_mapper = metadata_mapper or (
            lambda _: {
                "namespace": self.astra_db_env.database.namespace,
                "api_endpoint": self.astra_db_env.database.api_endpoint,
                "collection": collection_name,
            }
        )

    def _to_langchain_doc(self, doc: dict[str, Any]) -> Document:
        """Convert one raw collection document into a LangChain Document."""
        return Document(
            page_content=self.page_content_mapper(doc),
            metadata=self.metadata_mapper(doc),
        )

    @override
    def lazy_load(self) -> Iterator[Document]:
        """Lazily read matching documents from the collection."""
        for doc in self.astra_db_env.collection.find(
            filter=self.filter,
            projection=self._projection,
            limit=self.limit,
            # prefetch: not available at the moment (silently ignored)
            # prefetched=self.nb_prefetched,
        ):
            yield self._to_langchain_doc(doc)

    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [doc async for doc in self.alazy_load()]

    @override
    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async counterpart of ``lazy_load``."""
        async for doc in self.astra_db_env.async_collection.find(
            filter=self.filter,
            projection=self._projection,
            limit=self.limit,
            # prefetch: not available at the moment (silently ignored):
            # prefetched=self.nb_prefetched,
        ):
            yield self._to_langchain_doc(doc)