"""Ensemble retriever that ensemble the results ofmultiple retrievers by using weighted Reciprocal Rank Fusion"""importasynciofromcollectionsimportdefaultdictfromcollections.abcimportHashablefromitertoolsimportchainfromtypingimport(Any,Callable,Dict,Iterable,Iterator,List,Optional,TypeVar,cast,)fromlangchain_core.callbacksimport(AsyncCallbackManagerForRetrieverRun,CallbackManagerForRetrieverRun,)fromlangchain_core.documentsimportDocumentfromlangchain_core.retrieversimportBaseRetriever,RetrieverLikefromlangchain_core.runnablesimportRunnableConfigfromlangchain_core.runnables.configimportensure_config,patch_configfromlangchain_core.runnables.utilsimport(ConfigurableFieldSpec,get_unique_config_specs,)frompydanticimportmodel_validatorT=TypeVar("T")H=TypeVar("H",bound=Hashable)
def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
    """Lazily yield the first element seen for each distinct key.

    Elements are emitted in their original order; any later element whose key
    has already been produced is skipped.

    Args:
        iterable: The elements to filter.
        key: Maps an element to the hashable value used to detect duplicates.

    Yields:
        Each element whose key has not been encountered before.
    """
    observed = set()
    for item in iterable:
        marker = key(item)
        if marker in observed:
            # A previous element already claimed this key; drop the duplicate.
            continue
        observed.add(marker)
        yield item
class EnsembleRetriever(BaseRetriever):
    """Retriever that ensembles the results of multiple retrievers.

    The ranked lists returned by the individual retrievers are merged with
    weighted Reciprocal Rank Fusion (RRF).

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to
            equal weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the
            importance of high-ranked items and the consideration given to
            lower-ranked items. Default is 60.
        id_key: The key in the document's metadata used to determine unique
            documents. If not specified, page_content is used.
    """

    retrievers: List[RetrieverLike]
    weights: List[float]
    c: int = 60
    id_key: Optional[str] = None

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """List configurable fields for this runnable."""
        return get_unique_config_specs(
            spec for retriever in self.retrievers for spec in retriever.config_specs
        )

    @model_validator(mode="before")
    @classmethod
    def set_weights(cls, values: Dict[str, Any]) -> Any:
        """Fill in equal weights when none (or an empty list) were supplied.

        Args:
            values: The raw input handed to the model before field validation.

        Returns:
            The input unchanged when weights are already present (or when the
            input is not a dict), otherwise a copy with ``weights`` set to
            ``1 / len(retrievers)`` for every retriever.
        """
        # A "before" validator receives the raw input; anything that is not a
        # dict is passed through untouched for the regular validation to handle.
        if not isinstance(values, dict) or values.get("weights"):
            return values
        n_retrievers = len(values["retrievers"])
        # Guard the division: with no retrievers there is nothing to weight.
        equal_weight = 1 / n_retrievers if n_retrievers else 0.0
        # Build a new dict rather than mutating the caller's mapping.
        return {**values, "weights": [equal_weight] * n_retrievers}

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager forwarded to ``rank_fusion``.

        Returns:
            A list of reranked documents.
        """
        return self.rank_fusion(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager forwarded to ``arank_fusion``.

        Returns:
            A list of reranked documents.
        """
        return await self.arank_fusion(query, run_manager)

    def rank_fusion(
        self,
        query: str,
        run_manager: CallbackManagerForRetrieverRun,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[Document]:
        """Invoke every retriever and fuse their results.

        Args:
            query: The query to search for.
            run_manager: Callback manager; each retriever gets a child manager
                tagged ``retriever_<position>`` (1-based).
            config: Optional runnable config patched with the child callbacks
                before being passed to each retriever.

        Returns:
            A list of reranked documents.
        """
        retriever_docs = [
            retriever.invoke(
                query,
                patch_config(
                    config,
                    callbacks=run_manager.get_child(tag=f"retriever_{position}"),
                ),
            )
            for position, retriever in enumerate(self.retrievers, start=1)
        ]
        return self.weighted_reciprocal_rank(self._coerce_documents(retriever_docs))

    async def arank_fusion(
        self,
        query: str,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[Document]:
        """Asynchronously invoke every retriever and fuse their results.

        The retrievers are awaited together via ``asyncio.gather``.

        Args:
            query: The query to search for.
            run_manager: Callback manager; each retriever gets a child manager
                tagged ``retriever_<position>`` (1-based).
            config: Optional runnable config patched with the child callbacks
                before being passed to each retriever.

        Returns:
            A list of reranked documents.
        """
        retriever_docs = await asyncio.gather(
            *(
                retriever.ainvoke(
                    query,
                    patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"retriever_{position}"),
                    ),
                )
                for position, retriever in enumerate(self.retrievers, start=1)
            )
        )
        return self.weighted_reciprocal_rank(self._coerce_documents(retriever_docs))

    @staticmethod
    def _coerce_documents(results: Iterable[Iterable[Any]]) -> List[List[Document]]:
        """Wrap plain-string results in ``Document`` objects.

        Shared by the sync and async paths so both normalise retriever output
        the same way; every non-string item is passed through unchanged.
        """
        return [
            [Document(page_content=doc) if isinstance(doc, str) else doc for doc in docs]
            for docs in results
        ]

    def _doc_key(self, doc: Document) -> Hashable:
        """Return the value that identifies ``doc`` for scoring and dedup."""
        if self.id_key is None:
            return doc.page_content
        return doc.metadata[self.id_key]

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Each document contributes ``weight / (rank + c)`` per list it appears
        in, with ranks starting at 1. Documents sharing the same identity
        (``page_content``, or ``metadata[id_key]`` when ``id_key`` is set)
        accumulate their scores and only the first occurrence is returned.

        Args:
            doc_lists: A list of rank lists, where each rank list contains
                unique items.

        Returns:
            list: The final aggregated list of items sorted by their weighted
            RRF scores in descending order.

        Raises:
            ValueError: If the number of rank lists differs from the number of
                weights.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # Accumulate one score per document identity across all rank lists.
        rrf_score: Dict[Hashable, float] = defaultdict(float)
        for doc_list, weight in zip(doc_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score[self._doc_key(doc)] += weight / (rank + self.c)

        # Keep the first document seen for each identity, then order by score.
        unique_docs = unique_by_key(chain.from_iterable(doc_lists), self._doc_key)
        return sorted(
            unique_docs,
            key=lambda doc: rrf_score[self._doc_key(doc)],
            reverse=True,
        )