Source code for langchain.retrievers.merger_retriever
import asyncio
from typing import List
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
[docs]
class MergerRetriever(BaseRetriever):
    """Retriever that merges the results of multiple retrievers.

    Every configured retriever is queried with the same query and the
    resulting document lists are interleaved round-robin: the first
    document of each retriever (in ``retrievers`` order), then the second
    document of each, and so on. Retrievers that returned fewer documents
    simply drop out of later rounds. No de-duplication is performed.
    """

    retrievers: List[BaseRetriever]
    """A list of retrievers to merge."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this retriever run; forwarded
                to ``merge_documents``.

        Returns:
            A list of relevant documents.
        """
        # Merge the results of the retrievers.
        return self.merge_documents(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this retriever run;
                forwarded to ``amerge_documents``.

        Returns:
            A list of relevant documents.
        """
        # Merge the results of the retrievers.
        return await self.amerge_documents(query, run_manager)

    @staticmethod
    def _interleave_documents(
        retriever_docs: List[List[Document]],
    ) -> List[Document]:
        """Interleave several document lists round-robin.

        Shared by ``merge_documents`` and ``amerge_documents`` so the sync
        and async paths always produce the same ordering.

        Args:
            retriever_docs: One list of documents per retriever, in
                retriever order.

        Returns:
            A flat list holding the rank-0 document of every list, then the
            rank-1 document of every list, and so on, skipping lists that
            are already exhausted.
        """
        merged_documents: List[Document] = []
        # ``default=0`` keeps ``max`` from raising when there are no lists.
        max_docs = max(map(len, retriever_docs), default=0)
        for rank in range(max_docs):
            for docs in retriever_docs:
                if rank < len(docs):
                    merged_documents.append(docs[rank])
        return merged_documents

    def merge_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Merge the results of the retrievers.

        The retrievers are invoked one after another, in ``retrievers``
        order, and their results are interleaved round-robin.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this run. Each retriever is
                invoked with ``run_manager.get_child("retriever_<n>")`` as
                its callbacks, where ``<n>`` is the retriever's 1-based
                position.

        Returns:
            A list of merged documents.
        """
        # Get the results of all retrievers.
        retriever_docs = [
            retriever.invoke(
                query,
                config={
                    "callbacks": run_manager.get_child(f"retriever_{position}")
                },
            )
            for position, retriever in enumerate(self.retrievers, start=1)
        ]
        # Merge the results of the retrievers.
        return self._interleave_documents(retriever_docs)

    async def amerge_documents(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously merge the results of the retrievers.

        All retrievers are awaited together through ``asyncio.gather`` and
        their results are interleaved round-robin.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this run. Each
                retriever is invoked with
                ``run_manager.get_child("retriever_<n>")`` as its
                callbacks, where ``<n>`` is the retriever's 1-based
                position.

        Returns:
            A list of merged documents.
        """
        # Get the results of all retrievers.
        retriever_docs = await asyncio.gather(
            *(
                retriever.ainvoke(
                    query,
                    config={
                        "callbacks": run_manager.get_child(
                            f"retriever_{position}"
                        )
                    },
                )
                for position, retriever in enumerate(self.retrievers, start=1)
            )
        )
        # Merge the results of the retrievers.
        return self._interleave_documents(retriever_docs)