Source code for langchain.chains.router.multi_retrieval_qa
"""Use a single chain to route an input to one of multiple retrieval qa chains."""
from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain.chains import ConversationChain
from langchain.chains.base import Chain
from langchain.chains.conversation.prompt import DEFAULT_TEMPLATE
from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
from langchain.chains.router.base import MultiRouteChain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_retrieval_prompt import (
MULTI_RETRIEVAL_ROUTER_TEMPLATE,
)
[docs]
class MultiRetrievalQAChain(MultiRouteChain): # type: ignore[override]
"""A multi-route chain that uses an LLM router chain to choose amongst retrieval
qa chains."""
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
destination_chains: Mapping[str, BaseRetrievalQA]
"""Map of name to candidate chains that inputs can be routed to."""
default_chain: Chain
"""Default chain to use when router doesn't map input to one of the destinations."""
@property
def output_keys(self) -> List[str]:
return ["result"]
[docs]
@classmethod
def from_retrievers(
cls,
llm: BaseLanguageModel,
retriever_infos: List[Dict[str, Any]],
default_retriever: Optional[BaseRetriever] = None,
default_prompt: Optional[PromptTemplate] = None,
default_chain: Optional[Chain] = None,
*,
default_chain_llm: Optional[BaseLanguageModel] = None,
**kwargs: Any,
) -> MultiRetrievalQAChain:
if default_prompt and not default_retriever:
raise ValueError(
"`default_retriever` must be specified if `default_prompt` is "
"provided. Received only `default_prompt`."
)
destinations = [f"{r['name']}: {r['description']}" for r in retriever_infos]
destinations_str = "\n".join(destinations)
router_template = MULTI_RETRIEVAL_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(next_inputs_inner_key="query"),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
destination_chains = {}
for r_info in retriever_infos:
prompt = r_info.get("prompt")
retriever = r_info["retriever"]
chain = RetrievalQA.from_llm(llm, prompt=prompt, retriever=retriever)
name = r_info["name"]
destination_chains[name] = chain
if default_chain:
_default_chain = default_chain
elif default_retriever:
_default_chain = RetrievalQA.from_llm(
llm, prompt=default_prompt, retriever=default_retriever
)
else:
prompt_template = DEFAULT_TEMPLATE.replace("input", "query")
prompt = PromptTemplate(
template=prompt_template, input_variables=["history", "query"]
)
if default_chain_llm is None:
raise NotImplementedError(
"conversation_llm must be provided if default_chain is not "
"specified. This API has been changed to avoid instantiating "
"default LLMs on behalf of users."
"You can provide a conversation LLM like so:\n"
"from langchain_openai import ChatOpenAI\n"
"llm = ChatOpenAI()"
)
_default_chain = ConversationChain(
llm=default_chain_llm,
prompt=prompt,
input_key="query",
output_key="result",
)
return cls(
router_chain=router_chain,
destination_chains=destination_chains,
default_chain=_default_chain,
**kwargs,
)