# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template=(
        "You are an AI language model assistant. Your task is to generate 3 "
        "different versions of the given user question to retrieve relevant "
        "documents from a vector database. By generating multiple perspectives "
        "on the user question, your goal is to help the user overcome some of "
        "the limitations of distance-based similarity search. Provide these "
        "alternative questions separated by newlines. Original question: "
        "{question}"
    ),
)


def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    """Return ``documents`` with duplicates removed, keeping first occurrences.

    Documents are compared by equality (Python ``in`` membership), not by
    hashing, so unhashable documents are supported. Input order is preserved.

    Args:
        documents: Documents to deduplicate, possibly containing repeats.

    Returns:
        A new list holding the first occurrence of each distinct document.
    """
    unique: List[Document] = []
    for doc in documents:
        # Check against the documents kept so far instead of re-slicing the
        # input prefix each time: no per-item copy, and already-dropped
        # duplicates are not compared again. Equality is transitive, so the
        # result is the same as checking against the whole prefix.
        if doc not in unique:
            unique.append(doc)
    return unique
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    # Underlying retriever, invoked once per generated query.
    retriever: BaseRetriever
    # Runnable invoked with {"question": ...}; its output (or, for a legacy
    # LLMChain, the "text" entry of its output) is used as the query list.
    llm_chain: Runnable
    # When True, the generated queries are logged at INFO level.
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from an LLM, a prompt and a line-splitting output parser.

        The resulting ``llm_chain`` is ``prompt | llm | LineListOutputParser()``.

        Args:
            retriever: Retriever to query documents from.
            llm: LLM used to generate the alternative queries from ``prompt``.
            prompt: Prompt that asks for several different versions of the
                user question. It is filled with a single ``question``
                variable. Defaults to ``DEFAULT_QUERY_PROMPT``.
            parser_key: DEPRECATED and ignored. Accepted only so that older
                callers passing it keep working.
            include_original: Whether to include the original query in the
                list of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = prompt | llm | output_parser
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    def _queries_from_response(self, response) -> List[str]:
        """Extract (and optionally log) the query list from an llm_chain output.

        A legacy ``LLMChain`` wraps its parsed output under the ``"text"`` key.
        Any other runnable returns the parsed output directly.
        """
        if isinstance(self.llm_chain, LLMChain):
            lines = response["text"]
        else:
            lines = response
        if self.verbose:
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info("Generated queries: %s", lines)
        return lines

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: Callback manager whose child is passed to every
                LLM and retriever call.

        Returns:
            Unique union of relevant documents from all generated queries
            (plus the original query when ``include_original`` is set).
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            # Build a new list rather than appending in place, so the list
            # returned by agenerate_queries (an override may hand back a list
            # it still owns) is never mutated.
            queries = [*queries, query]
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: Callback manager whose child is passed to the LLM call.

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.ainvoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        return self._queries_from_response(response)

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list
            run_manager: Callback manager whose child is passed to each
                retriever call.

        Returns:
            List of retrieved Documents, flattened in query order (duplicates
            are kept here and removed later by ``unique_union``).
        """
        # asyncio.gather returns results in the order of the awaitables given.
        document_lists = await asyncio.gather(
            *(
                self.retriever.ainvoke(
                    query, config={"callbacks": run_manager.get_child()}
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: Callback manager whose child is passed to every
                LLM and retriever call.

        Returns:
            Unique union of relevant documents from all generated queries
            (plus the original query when ``include_original`` is set).
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            # Build a new list rather than appending in place, so the list
            # returned by generate_queries (an override may hand back a list
            # it still owns) is never mutated.
            queries = [*queries, query]
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: Callback manager whose child is passed to the LLM call.

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain.invoke(
            {"question": question}, config={"callbacks": run_manager.get_child()}
        )
        return self._queries_from_response(response)

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries sequentially.

        Args:
            queries: query list
            run_manager: Callback manager whose child is passed to each
                retriever call.

        Returns:
            List of retrieved Documents, flattened in query order (duplicates
            are kept here and removed later by ``unique_union``).
        """
        return [
            doc
            for query in queries
            for doc in self.retriever.invoke(
                query, config={"callbacks": run_manager.get_child()}
            )
        ]

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents, in first-occurrence order
        """
        return _unique_documents(documents)