"""Chain pipeline where the outputs of one step feed directly into next."""fromtypingimportAny,Dict,List,Optionalfromlangchain_core.callbacksimport(AsyncCallbackManagerForChainRun,CallbackManagerForChainRun,)fromlangchain_core.utils.inputimportget_color_mappingfrompydanticimportConfigDict,model_validatorfromtyping_extensionsimportSelffromlangchain.chains.baseimportChain
class SequentialChain(Chain):
    """Chain where the outputs of one chain feed directly into next.

    The sub-chains are run in order. Every sub-chain is called with all
    values known so far (the caller's inputs plus the outputs of every
    earlier sub-chain), and its outputs are merged into that pool. Only the
    keys listed in ``output_variables`` are returned.
    """

    # Sub-chains, executed in list order.
    chains: List[Chain]
    # Keys the caller must supply when running this chain.
    input_variables: List[str]
    # Keys returned to the caller; inferred by ``validate_chains`` if omitted.
    output_variables: List[str]  #: :meta private:
    # When ``output_variables`` is omitted: True exposes every key produced by
    # any sub-chain, False exposes only the last sub-chain's output keys.
    return_all: bool = False

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Return expected input keys to the chain.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return self.output_variables

    @model_validator(mode="before")
    @classmethod
    def validate_chains(cls, values: Dict) -> Any:
        """Validate that the correct inputs exist for all chains.

        Walks the sub-chains in order while tracking which variable names are
        available at each step, and fills in ``output_variables`` when the
        caller did not provide it.

        Args:
            values: Raw constructor values. ``chains`` and ``input_variables``
                are read unconditionally; ``memory``, ``return_all`` and
                ``output_variables`` are optional.

        Returns:
            The same ``values`` mapping, with ``output_variables`` populated.

        Raises:
            ValueError: If input keys overlap with the memory keys, a
                sub-chain needs a key that is not available yet, a sub-chain
                produces a key that already exists, a requested output
                variable is never produced, or ``output_variables`` has to be
                inferred from the last chain but ``chains`` is empty.
        """
        chains = values["chains"]
        input_variables = values["input_variables"]
        memory_keys = []
        if "memory" in values and values["memory"] is not None:
            # Validate that prompt input variables are consistent: a key may
            # come from the caller or from memory, but not from both.
            memory_keys = values["memory"].memory_variables
            overlapping_keys = set(input_variables) & set(memory_keys)
            if overlapping_keys:
                # Sorted and comma-separated so the keys stay readable and
                # the message is deterministic.
                raise ValueError(
                    f"The input key(s) {', '.join(sorted(overlapping_keys))} "
                    f"are found "
                    f"in the Memory keys ({memory_keys}) - please use input and "
                    f"memory keys that don't overlap."
                )

        # Set union instead of list concatenation, so neither operand has to
        # be a list.
        known_variables = set(input_variables).union(memory_keys)

        for chain in chains:
            missing_vars = set(chain.input_keys).difference(known_variables)
            if chain.memory:
                # Keys supplied by the sub-chain's own memory are not missing.
                missing_vars = missing_vars.difference(chain.memory.memory_variables)

            if missing_vars:
                raise ValueError(
                    f"Missing required input keys: {missing_vars}, "
                    f"only had {known_variables}"
                )
            overlapping_keys = known_variables.intersection(chain.output_keys)
            if overlapping_keys:
                raise ValueError(
                    f"Chain returned keys that already exist: {overlapping_keys}"
                )

            known_variables |= set(chain.output_keys)

        if "output_variables" not in values:
            if values.get("return_all", False):
                output_keys = known_variables.difference(input_variables)
            else:
                if not chains:
                    # Without this guard ``chains[-1]`` raises a bare
                    # IndexError that does not explain the actual problem.
                    raise ValueError(
                        "Cannot infer output_variables: `chains` is empty. "
                        "Provide at least one chain or set output_variables."
                    )
                output_keys = chains[-1].output_keys
            values["output_variables"] = output_keys
        else:
            missing_vars = set(values["output_variables"]).difference(known_variables)
            if missing_vars:
                raise ValueError(
                    f"Expected output variables that were not found: {missing_vars}."
                )

        return values

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run every sub-chain in order and return the selected outputs.

        Args:
            inputs: Initial values, keyed by variable name.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            The values of ``self.output_variables`` after all sub-chains ran.
        """
        # Copy so the caller's dict is not mutated by the updates below.
        known_values = inputs.copy()
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        for chain in self.chains:
            callbacks = _run_manager.get_child()
            outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks)
            known_values.update(outputs)
        return {k: known_values[k] for k in self.output_variables}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously run every sub-chain in order.

        Args:
            inputs: Initial values, keyed by variable name.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            The values of ``self.output_variables`` after all sub-chains ran.
        """
        # Copy so the caller's dict is not mutated by the updates below.
        known_values = inputs.copy()
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        for chain in self.chains:
            # One child manager per step, matching the synchronous ``_call``.
            callbacks = _run_manager.get_child()
            outputs = await chain.acall(
                known_values, return_only_outputs=True, callbacks=callbacks
            )
            known_values.update(outputs)
        return {k: known_values[k] for k in self.output_variables}
class SimpleSequentialChain(Chain):
    """Simple chain where the outputs of one step feed directly into next.

    Every sub-chain must have exactly one input key and one output key; the
    single output of each step is passed as the single input of the next.
    """

    # Sub-chains, executed in list order.
    chains: List[Chain]
    # When True, ``.strip()`` is applied to each step's output before it is
    # passed on.
    strip_outputs: bool = False
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @model_validator(mode="after")
    def validate_chains(self) -> Self:
        """Validate that chains are all single input/output.

        Returns:
            This instance, unchanged.

        Raises:
            ValueError: If any sub-chain does not have exactly one input key
                or exactly one output key.
        """
        for chain in self.chains:
            if len(chain.input_keys) != 1:
                raise ValueError(
                    "Chains used in SimpleSequentialChain should all have one "
                    "input, got "
                    f"{chain} with {len(chain.input_keys)} inputs."
                )
            if len(chain.output_keys) != 1:
                raise ValueError(
                    "Chains used in SimpleSequentialChain should all have one "
                    "output, got "
                    f"{chain} with {len(chain.output_keys)} outputs."
                )
        return self

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the sub-chains in order, piping each output into the next.

        Args:
            inputs: Mapping that contains the value for ``self.input_key``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the last
            step's output.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _input = inputs[self.input_key]
        # One colour per step, keyed by the step index as a string; used for
        # the ``on_text`` report below.
        color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
        for i, chain in enumerate(self.chains):
            # Child callbacks are tagged with a 1-based step label.
            _input = chain.run(
                _input, callbacks=_run_manager.get_child(f"step_{i + 1}")
            )
            if self.strip_outputs:
                _input = _input.strip()
            _run_manager.on_text(
                _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
            )
        return {self.output_key: _input}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously run the sub-chains in order.

        Args:
            inputs: Mapping that contains the value for ``self.input_key``.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the last
            step's output.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        _input = inputs[self.input_key]
        # One colour per step, keyed by the step index as a string; used for
        # the ``on_text`` report below.
        color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
        for i, chain in enumerate(self.chains):
            # Child callbacks are tagged with a 1-based step label.
            _input = await chain.arun(
                _input, callbacks=_run_manager.get_child(f"step_{i + 1}")
            )
            if self.strip_outputs:
                _input = _input.strip()
            await _run_manager.on_text(
                _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
            )
        return {self.output_key: _input}