"""Base classes for chain routing."""from__future__importannotationsfromabcimportABCfromtypingimportAny,Dict,List,Mapping,NamedTuple,Optionalfromlangchain_core.callbacksimport(AsyncCallbackManagerForChainRun,CallbackManagerForChainRun,Callbacks,)frompydanticimportConfigDictfromlangchain.chains.baseimportChain
class Route(NamedTuple):
    """Routing decision produced by a :class:`RouterChain`."""

    # Name of the destination chain. A falsy value (None or "") makes
    # MultiRouteChain fall back to its default chain.
    destination: Optional[str]
    # Inputs to forward to the selected destination chain.
    next_inputs: Dict[str, Any]


class RouterChain(Chain, ABC):
    """Chain that outputs the name of a destination chain and the inputs to it."""

    @property
    def output_keys(self) -> List[str]:
        """Keys produced by the router: ``destination`` and ``next_inputs``."""
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
        """Route inputs to a destination chain.

        Args:
            inputs: Inputs to the router chain.
            callbacks: Callbacks to use while running the router chain.

        Returns:
            A Route built from the chain's ``destination`` and ``next_inputs``
            outputs.
        """
        result = self(inputs, callbacks=callbacks)
        return Route(result["destination"], result["next_inputs"])

    async def aroute(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Route:
        """Asynchronously route inputs to a destination chain.

        Async counterpart of :meth:`route`; required by
        ``MultiRouteChain._acall``.

        Args:
            inputs: Inputs to the router chain.
            callbacks: Callbacks to use while running the router chain.

        Returns:
            A Route built from the chain's ``destination`` and ``next_inputs``
            outputs.
        """
        result = await self.acall(inputs, callbacks=callbacks)
        return Route(result["destination"], result["next_inputs"])
class MultiRouteChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: RouterChain
    """Chain that routes inputs to destination chains."""
    destination_chains: Mapping[str, Chain]
    """Chains that return final answer to inputs."""
    default_chain: Chain
    """Default chain to use when none of the destination chains are suitable."""
    silent_errors: bool = False
    """If True, use default_chain when an invalid destination name is provided.
    Defaults to False."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Empty list: the outputs depend on which destination chain runs.

        :meta private:
        """
        return []

    def _select_chain(self, route: Route) -> Chain:
        """Return the chain that should handle ``route``.

        Shared by ``_call`` and ``_acall`` so both paths apply identical
        routing rules.

        Args:
            route: Routing decision produced by ``router_chain``.

        Returns:
            ``default_chain`` when the route has no (falsy) destination, the
            matching entry of ``destination_chains`` when the name is known,
            or ``default_chain`` for an unknown name if ``silent_errors`` is
            True.

        Raises:
            ValueError: If the destination name is unknown and
                ``silent_errors`` is False.
        """
        if not route.destination:
            return self.default_chain
        if route.destination in self.destination_chains:
            return self.destination_chains[route.destination]
        if self.silent_errors:
            return self.default_chain
        raise ValueError(
            f"Received invalid destination chain name '{route.destination}'"
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain synchronously."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = self.router_chain.route(inputs, callbacks=callbacks)
        # Report the routing decision before dispatching (and before any
        # invalid-destination error is raised).
        _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs),
            verbose=self.verbose,
        )
        chain = self._select_chain(route)
        return chain(route.next_inputs, callbacks=callbacks)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain asynchronously."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = await self.router_chain.aroute(inputs, callbacks=callbacks)
        # Report the routing decision before dispatching (and before any
        # invalid-destination error is raised).
        await _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs),
            verbose=self.verbose,
        )
        chain = self._select_chain(route)
        return await chain.acall(route.next_inputs, callbacks=callbacks)