Source code for langchain_community.document_loaders.mongodb

import asyncio
import logging
from typing import Dict, List, Optional, Sequence

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader

logger = logging.getLogger(__name__)


[docs] class MongodbLoader(BaseLoader): """Load MongoDB documents."""
[docs] def __init__( self, connection_string: str, db_name: str, collection_name: str, *, filter_criteria: Optional[Dict] = None, field_names: Optional[Sequence[str]] = None, metadata_names: Optional[Sequence[str]] = None, include_db_collection_in_metadata: bool = True, ) -> None: """ Initializes the MongoDB loader with necessary database connection details and configurations. Args: connection_string (str): MongoDB connection URI. db_name (str):Name of the database to connect to. collection_name (str): Name of the collection to fetch documents from. filter_criteria (Optional[Dict]): MongoDB filter criteria for querying documents. field_names (Optional[Sequence[str]]): List of field names to retrieve from documents. metadata_names (Optional[Sequence[str]]): Additional metadata fields to extract from documents. include_db_collection_in_metadata (bool): Flag to include database and collection names in metadata. Raises: ImportError: If the motor library is not installed. ValueError: If any necessary argument is missing. """ try: from motor.motor_asyncio import AsyncIOMotorClient except ImportError as e: raise ImportError( "Cannot import from motor, please install with `pip install motor`." ) from e if not connection_string: raise ValueError("connection_string must be provided.") if not db_name: raise ValueError("db_name must be provided.") if not collection_name: raise ValueError("collection_name must be provided.") self.client = AsyncIOMotorClient(connection_string) self.db_name = db_name self.collection_name = collection_name self.field_names = field_names or [] self.filter_criteria = filter_criteria or {} self.metadata_names = metadata_names or [] self.include_db_collection_in_metadata = include_db_collection_in_metadata self.db = self.client.get_database(db_name) self.collection = self.db.get_collection(collection_name)
[docs] def load(self) -> List[Document]: """Load data into Document objects. Attention: This implementation starts an asyncio event loop which will only work if running in a sync env. In an async env, it should fail since there is already an event loop running. This code should be updated to kick off the event loop from a separate thread if running within an async context. """ return asyncio.run(self.aload())
[docs] async def aload(self) -> List[Document]: """Asynchronously loads data into Document objects.""" result = [] total_docs = await self.collection.count_documents(self.filter_criteria) projection = self._construct_projection() async for doc in self.collection.find(self.filter_criteria, projection): metadata = self._extract_fields(doc, self.metadata_names, default="") # Optionally add database and collection names to metadata if self.include_db_collection_in_metadata: metadata.update( {"database": self.db_name, "collection": self.collection_name} ) # Extract text content from filtered fields or use the entire document if self.field_names is not None: fields = self._extract_fields(doc, self.field_names, default="") texts = [str(value) for value in fields.values()] text = " ".join(texts) else: text = str(doc) result.append(Document(page_content=text, metadata=metadata)) if len(result) != total_docs: logger.warning( f"Only partial collection of documents returned. " f"Loaded {len(result)} docs, expected {total_docs}." ) return result
def _construct_projection(self) -> Optional[Dict]: """Constructs the projection dictionary for MongoDB query based on the specified field names and metadata names.""" field_names = list(self.field_names) or [] metadata_names = list(self.metadata_names) or [] all_fields = field_names + metadata_names return {field: 1 for field in all_fields} if all_fields else None def _extract_fields( self, document: Dict, fields: Sequence[str], default: str = "", ) -> Dict: """Extracts and returns values for specified fields from a document.""" extracted = {} for field in fields or []: value = document for key in field.split("."): value = value.get(key, default) if value == default: break new_field_name = field.replace(".", "_") extracted[new_field_name] = value return extracted