# Source code for langchain.chains.router.multi_retrieval_qa
"""Use a single chain to route an input to one of multiple retrieval qa chains."""from__future__importannotationsfromtypingimportAny,Dict,List,Mapping,Optionalfromlangchain_core.language_modelsimportBaseLanguageModelfromlangchain_core.promptsimportPromptTemplatefromlangchain_core.retrieversimportBaseRetrieverfromlangchain.chainsimportConversationChainfromlangchain.chains.baseimportChainfromlangchain.chains.conversation.promptimportDEFAULT_TEMPLATEfromlangchain.chains.retrieval_qa.baseimportBaseRetrievalQA,RetrievalQAfromlangchain.chains.router.baseimportMultiRouteChainfromlangchain.chains.router.llm_routerimportLLMRouterChain,RouterOutputParserfromlangchain.chains.router.multi_retrieval_promptimport(MULTI_RETRIEVAL_ROUTER_TEMPLATE,)
class MultiRetrievalQAChain(MultiRouteChain):  # type: ignore[override]
    """A multi-route chain that uses an LLM router chain to choose amongst retrieval
    qa chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    destination_chains: Mapping[str, BaseRetrievalQA]
    """Map of name to candidate chains that inputs can be routed to."""
    default_chain: Chain
    """Default chain to use when router doesn't map input to one of the destinations."""

    @property
    def output_keys(self) -> List[str]:
        # All destination RetrievalQA chains emit their answer under "result",
        # so the routed chain exposes that single key.
        return ["result"]
[docs]@classmethoddeffrom_retrievers(cls,llm:BaseLanguageModel,retriever_infos:List[Dict[str,Any]],default_retriever:Optional[BaseRetriever]=None,default_prompt:Optional[PromptTemplate]=None,default_chain:Optional[Chain]=None,*,default_chain_llm:Optional[BaseLanguageModel]=None,**kwargs:Any,)->MultiRetrievalQAChain:ifdefault_promptandnotdefault_retriever:raiseValueError("`default_retriever` must be specified if `default_prompt` is ""provided. Received only `default_prompt`.")destinations=[f"{r['name']}: {r['description']}"forrinretriever_infos]destinations_str="\n".join(destinations)router_template=MULTI_RETRIEVAL_ROUTER_TEMPLATE.format(destinations=destinations_str)router_prompt=PromptTemplate(template=router_template,input_variables=["input"],output_parser=RouterOutputParser(next_inputs_inner_key="query"),)router_chain=LLMRouterChain.from_llm(llm,router_prompt)destination_chains={}forr_infoinretriever_infos:prompt=r_info.get("prompt")retriever=r_info["retriever"]chain=RetrievalQA.from_llm(llm,prompt=prompt,retriever=retriever)name=r_info["name"]destination_chains[name]=chainifdefault_chain:_default_chain=default_chainelifdefault_retriever:_default_chain=RetrievalQA.from_llm(llm,prompt=default_prompt,retriever=default_retriever)else:prompt_template=DEFAULT_TEMPLATE.replace("input","query")prompt=PromptTemplate(template=prompt_template,input_variables=["history","query"])ifdefault_chain_llmisNone:raiseNotImplementedError("conversation_llm must be provided if default_chain is not ""specified. This API has been changed to avoid instantiating ""default LLMs on behalf of users.""You can provide a conversation LLM like so:\n""from langchain_openai import ChatOpenAI\n""llm = ChatOpenAI()")_default_chain=ConversationChain(llm=default_chain_llm,prompt=prompt,input_key="query",output_key="result",)returncls(router_chain=router_chain,destination_chains=destination_chains,default_chain=_default_chain,**kwargs,)