# Source code for langchain_mongodb.retrievers.parent_document

from __future__ import annotations

from importlib.metadata import version
from typing import Any, List

import pymongo
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain_core.callbacks import (
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import TextSplitter
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.docstores import MongoDBDocStore
from langchain_mongodb.pipelines import vector_search_stage
from langchain_mongodb.utils import make_serializable


class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
    """MongoDB Atlas's ParentDocumentRetriever

    Uses ONE Collection for both Vector and Doc store: child chunks (with
    their embeddings) and the parent documents they were split from live
    side by side in the same MongoDB collection.

    For details, see parent classes
        :class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
        and :class:`~langchain.retrievers.MultiVectorRetriever`.

    Examples:
        >>> from langchain_core.documents import Document
        >>> from langchain_mongodb.retrievers.parent_document import (
        ...     MongoDBAtlasParentDocumentRetriever
        ... )
        >>> from langchain_text_splitters import RecursiveCharacterTextSplitter
        >>> from langchain_openai import OpenAIEmbeddings
        >>>
        >>> retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
        ...     "mongodb+srv://<user>:<password>@<clustername>.mongodb.net",
        ...     OpenAIEmbeddings(model="text-embedding-3-large"),
        ...     RecursiveCharacterTextSplitter(chunk_size=400),
        ...     "example_database",
        ... )
        >>> retriever.add_documents([Document(...), ...])  # e.g. report pages
        >>> resp = retriever.invoke("Langchain MongoDB Partnership Ecosystem")
        >>> print(resp)
        [Document(...), ...]
    """

    vectorstore: MongoDBAtlasVectorSearch
    """Vectorstore API to add, embed, and search through child documents"""

    docstore: MongoDBDocStore
    """Provides an API around the Collection to add the parent documents"""

    id_key: str = "doc_id"
    """Key stored in metadata pointing to parent document"""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the parent documents of the child chunks most similar to ``query``.

        A single aggregation runs against the shared collection:

        1. an Atlas vector search over the child chunks;
        2. a ``$lookup`` from each chunk's ``id_key`` field to a parent's ``_id``;
        3. a ``$group`` that keeps each parent only once.

        Chunks whose lookup finds no parent are dropped by ``$unwind``.

        Args:
            query: Text to embed and search with.
            run_manager: Callback manager for this run (not used directly here).

        Returns:
            One Document per distinct parent. Parents are ordered by the best
            vector-search score among their matching chunks, highest first.
            The parent's ``vectorstore._text_key`` field becomes
            ``page_content``. Every remaining field, including ``_id``, is
            passed through ``make_serializable`` and kept as metadata.
        """
        query_vector = self.vectorstore._embedding.embed_query(query)
        embedding_key = self.vectorstore._embedding_key

        pipeline = [
            vector_search_stage(
                query_vector,
                embedding_key,
                self.vectorstore._index_name,
                **self.search_kwargs,  # See MongoDBAtlasVectorSearch
            ),
            {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            # Drop the child vectors early to keep the pipeline payload small.
            # Use the configured field name (same one the search stage uses)
            # rather than a hard-coded "embedding".
            {"$project": {embedding_key: 0}},
            {  # Find corresponding parent doc
                "$lookup": {
                    "from": self.vectorstore.collection.name,
                    "localField": self.id_key,
                    "foreignField": "_id",
                    "as": "parent_context",
                    "pipeline": [
                        # Discard sub-documents
                        {"$match": {f"metadata.{self.id_key}": {"$exists": False}}},
                    ],
                }
            },
            # Remove duplicate parent docs and reformat
            {"$unwind": {"path": "$parent_context"}},
            {
                "$group": {
                    "_id": "$parent_context._id",
                    "uniqueDocument": {"$first": "$parent_context"},
                    # Best child score per parent, used only for ordering.
                    "score": {"$max": "$score"},
                }
            },
            # $group does not order its output, so restore relevance order
            # before the score is discarded by $replaceRoot.
            {"$sort": {"score": -1}},
            {"$replaceRoot": {"newRoot": "$uniqueDocument"}},
        ]

        # Execute
        cursor = self.vectorstore._collection.aggregate(pipeline)  # type: ignore[arg-type]

        # Format into Documents
        docs = []
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        embedding_model: Embeddings,
        child_splitter: TextSplitter,
        database_name: str,
        collection_name: str = "document_with_chunks",
        id_key: str = "doc_id",
        **kwargs: Any,
    ) -> MongoDBAtlasParentDocumentRetriever:
        """Construct a Retriever using ONE Collection for both VectorStore and DocStore.

        See parent classes
        :class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
        and :class:`~langchain.retrievers.MultiVectorRetriever` for further details.

        Args:
            connection_string: A valid MongoDB Atlas connection URI.
            embedding_model: The text embedding model to use for the vector store.
            child_splitter: Splits documents into chunks.
                If parent_splitter is given, the documents will have already been split.
            database_name: Name of database to connect to. Created if it does not exist.
            collection_name: Name of collection to use.
                It includes parent documents, sub-documents and their embeddings.
            id_key: Key used to identify parent documents.
            **kwargs: Additional keyword arguments. They are forwarded both to
                ``MongoDBAtlasVectorSearch`` and to this class's constructor.
                See parent classes for more.

        Returns:
            A new MongoDBAtlasParentDocumentRetriever
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
        )
        # The same collection backs both the vector store and the doc store.
        collection = client[database_name][collection_name]
        vectorstore = MongoDBAtlasVectorSearch(
            collection=collection, embedding=embedding_model, **kwargs
        )
        docstore = MongoDBDocStore(collection=collection)
        # Index the child->parent link field used as localField in the $lookup.
        docstore.collection.create_index([(id_key, pymongo.ASCENDING)])
        return cls(
            vectorstore=vectorstore,
            docstore=docstore,
            child_splitter=child_splitter,
            id_key=id_key,
            **kwargs,
        )