Source code for langchain.chains.router.llm_router

"""Base classes for LLM-powered router chains."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Type, cast

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.utils.json import parse_and_check_json_markdown
from pydantic import model_validator
from typing_extensions import Self

from langchain.chains import LLMChain
from langchain.chains.router.base import RouterChain


@deprecated(
    since="0.2.12",
    removal="1.0",
    message=(
        "Use RunnableLambda to select from multiple prompt templates. See example "
        "in API reference: "
        "https://api.python.langchain.com/en/latest/chains/langchain.chains.router.llm_router.LLMRouterChain.html"  # noqa: E501
    ),
)
class LLMRouterChain(RouterChain):
    """A router chain that uses an LLM chain to perform routing.

    This class is deprecated. See below for a replacement, which offers several
    benefits, including streaming and batch support.

    Below is an example implementation:

        .. code-block:: python

            from operator import itemgetter
            from typing import Literal
            from typing_extensions import TypedDict

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnableLambda, RunnablePassthrough
            from langchain_openai import ChatOpenAI

            llm = ChatOpenAI(model="gpt-4o-mini")

            prompt_1 = ChatPromptTemplate.from_messages(
                [
                    ("system", "You are an expert on animals."),
                    ("human", "{query}"),
                ]
            )
            prompt_2 = ChatPromptTemplate.from_messages(
                [
                    ("system", "You are an expert on vegetables."),
                    ("human", "{query}"),
                ]
            )

            chain_1 = prompt_1 | llm | StrOutputParser()
            chain_2 = prompt_2 | llm | StrOutputParser()

            route_system = "Route the user's query to either the animal or vegetable expert."
            route_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", route_system),
                    ("human", "{query}"),
                ]
            )


            class RouteQuery(TypedDict):
                \"\"\"Route query to destination.\"\"\"
                destination: Literal["animal", "vegetable"]


            route_chain = (
                route_prompt
                | llm.with_structured_output(RouteQuery)
                | itemgetter("destination")
            )

            chain = {
                "destination": route_chain,  # "animal" or "vegetable"
                "query": lambda x: x["query"],  # pass through input query
            } | RunnableLambda(
                # if animal, chain_1. otherwise, chain_2.
                lambda x: chain_1 if x["destination"] == "animal" else chain_2,
            )

            chain.invoke({"query": "what color are carrots"})
    """  # noqa: E501

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @model_validator(mode="after")
    def validate_prompt(self) -> Self:
        """Ensure the routing prompt carries an output parser.

        ``_call`` / ``_acall`` rely on ``prompt.output_parser`` to turn the raw
        LLM text into a dict, so a missing parser is rejected at construction.

        Raises:
            ValueError: If ``llm_chain.prompt.output_parser`` is ``None``.
        """
        prompt = self.llm_chain.prompt
        if prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return self

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Run the base-class output checks, then require dict ``next_inputs``.

        Raises:
            ValueError: If ``outputs["next_inputs"]`` is not a ``dict``.
        """
        super()._validate_outputs(outputs)
        next_inputs = outputs["next_inputs"]
        if not isinstance(next_inputs, dict):
            raise ValueError(
                "Expected router output 'next_inputs' to be a dict, got "
                f"{type(next_inputs).__name__}."
            )

    def _parse_prediction(self, prediction: str) -> Dict[str, Any]:
        """Parse raw LLM text with the prompt's output parser.

        ``validate_prompt`` guarantees the parser is not ``None``.
        """
        return cast(
            Dict[str, Any],
            self.llm_chain.prompt.output_parser.parse(prediction),
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the routing LLM on ``inputs`` and parse its text into a dict."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
        return self._parse_prediction(prediction)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of ``_call``."""
        # Fall back to the *async* no-op manager so ``get_child`` yields an
        # async callback manager on this async code path (the original used
        # the sync ``CallbackManagerForChainRun`` here).
        _run_manager = (
            run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        )
        callbacks = _run_manager.get_child()
        prediction = await self.llm_chain.apredict(callbacks=callbacks, **inputs)
        return self._parse_prediction(prediction)

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor.

        Wraps ``llm`` and ``prompt`` in an ``LLMChain`` and forwards any extra
        ``kwargs`` to the ``LLMRouterChain`` constructor.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain."""

    # Destination name (compared case-insensitively against the stripped model
    # output) that is mapped to ``None`` in the parsed result.
    default_destination: str = "DEFAULT"
    # Type that the raw "next_inputs" value must be an instance of.
    next_inputs_type: Type = str
    # Key under which the raw "next_inputs" value is wrapped in the result.
    next_inputs_inner_key: str = "input"

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse router LLM output into ``destination`` / ``next_inputs``.

        ``text`` is parsed with ``parse_and_check_json_markdown`` and must contain
        the keys ``"destination"`` and ``"next_inputs"``.

        Returns:
            The parsed dict where ``next_inputs`` is wrapped as
            ``{next_inputs_inner_key: <raw value>}`` and ``destination`` is the
            stripped string, or ``None`` when it matches ``default_destination``
            (case-insensitive).

        Raises:
            OutputParserException: On any parsing or validation failure; the
                underlying error is chained as ``__cause__`` and the raw text is
                attached as ``llm_output``.
        """
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            if (
                parsed["destination"].strip().lower()
                == self.default_destination.lower()
            ):
                parsed["destination"] = None
            else:
                parsed["destination"] = parsed["destination"].strip()
            return parsed
        except Exception as e:
            # Every failure (including the ValueErrors above) is surfaced as an
            # OutputParserException, with the cause chained for debuggability.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}",
                llm_output=text,
            ) from e