"""CPAL Chain and its subchains"""from__future__importannotationsimportjsonfromtypingimportAny,ClassVar,Dict,List,Optional,Type,castimportpydanticfromlangchain.base_languageimportBaseLanguageModelfromlangchain.chains.baseimportChainfromlangchain.chains.llmimportLLMChainfromlangchain.output_parsersimportPydanticOutputParserfromlangchain_core.callbacks.managerimportCallbackManagerForChainRunfromlangchain_core.prompts.promptimportPromptTemplatefromlangchain_experimental.cpal.constantsimportConstantfromlangchain_experimental.cpal.modelsimport(CausalModel,InterventionModel,NarrativeModel,QueryModel,StoryModel,)fromlangchain_experimental.cpal.templates.univariate.causalimport(templateascausal_template,)fromlangchain_experimental.cpal.templates.univariate.interventionimport(templateasintervention_template,)fromlangchain_experimental.cpal.templates.univariate.narrativeimport(templateasnarrative_template,)fromlangchain_experimental.cpal.templates.univariate.queryimport(templateasquery_template,)class_BaseStoryElementChain(Chain):chain:LLMChaininput_key:str=Constant.narrative_input.value#: :meta private:output_key:str=Constant.chain_answer.value#: :meta private:pydantic_model:ClassVar[Optional[Type[pydantic.BaseModel]]]=(None#: :meta private:)template:ClassVar[Optional[str]]=None#: :meta private:@classmethoddefparser(cls)->PydanticOutputParser:"""Parse LLM output into a pydantic object."""ifcls.pydantic_modelisNone:raiseNotImplementedError(f"pydantic_model not implemented for {cls.__name__}")returnPydanticOutputParser(pydantic_object=cls.pydantic_model)@propertydefinput_keys(self)->List[str]:"""Return the input keys. :meta private: """return[self.input_key]@propertydefoutput_keys(self)->List[str]:"""Return the output keys. 
:meta private: """_output_keys=[self.output_key]return_output_keys@classmethoddeffrom_univariate_prompt(cls,llm:BaseLanguageModel,**kwargs:Any,)->Any:returncls(chain=LLMChain(llm=llm,prompt=PromptTemplate(input_variables=[Constant.narrative_input.value],template=kwargs.get("template",cls.template),partial_variables={"format_instructions":cls.parser().get_format_instructions()},),),**kwargs,)def_call(self,inputs:Dict[str,Any],run_manager:Optional[CallbackManagerForChainRun]=None,)->Dict[str,Any]:completion=self.chain.run(inputs[self.input_key])pydantic_data=self.__class__.parser().parse(completion)return{Constant.chain_data.value:pydantic_data,Constant.chain_answer.value:None,}
class NarrativeChain(_BaseStoryElementChain):
    """Split a narrative into its story elements.

    The elements extracted are:

    - the causal model
    - the query
    - the intervention
    """

    template: ClassVar[str] = narrative_template
    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = NarrativeModel
class CausalChain(_BaseStoryElementChain):
    """Turn the causal part of a narrative into a stack of operations."""

    template: ClassVar[str] = causal_template
    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = CausalModel
class InterventionChain(_BaseStoryElementChain):
    """Extract the hypothetical conditions to apply to the causal model."""

    template: ClassVar[str] = intervention_template
    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = InterventionModel
class QueryChain(_BaseStoryElementChain):
    """Produce a SQL query against the outcome table.

    *Security note*: This class implements an AI technique that generates
        SQL code. If those SQL commands are executed, it's critical to ensure
        they use credentials that are narrowly-scoped to only include the
        permissions this chain needs. Failure to do so may result in data
        corruption or loss, since this chain may attempt commands like
        `DROP TABLE` or `INSERT` if appropriately prompted. The best way to
        guard against such negative outcomes is to (as appropriate) limit the
        permissions granted to the credentials used with this chain.
    """

    # TODO: incl. table schema
    template: ClassVar[str] = query_template
    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = QueryModel
class CPALChain(_BaseStoryElementChain):
    """Causal program-aided language (CPAL) chain implementation.

    *Security note*: The building blocks of this class include the
        implementation of an AI technique that generates SQL code. If those
        SQL commands are executed, it's critical to ensure they use
        credentials that are narrowly-scoped to only include the permissions
        this chain needs. Failure to do so may result in data corruption or
        loss, since this chain may attempt commands like `DROP TABLE` or
        `INSERT` if appropriately prompted. The best way to guard against
        such negative outcomes is to (as appropriate) limit the permissions
        granted to the credentials used with this chain.
    """

    llm: BaseLanguageModel
    narrative_chain: Optional[NarrativeChain] = None
    causal_chain: Optional[CausalChain] = None
    intervention_chain: Optional[InterventionChain] = None
    query_chain: Optional[QueryChain] = None
    # Story built by the most recent run; read by draw().
    # TODO: change name of _story?
    _story: Optional[StoryModel] = pydantic.PrivateAttr(default=None)

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> CPALChain:
        """instantiation depends on component chains

        *Security note*: The building blocks of this class include the
            implementation of an AI technique that generates SQL code. If
            those SQL commands are executed, it's critical to ensure they use
            credentials that are narrowly-scoped to only include the
            permissions this chain needs. Failure to do so may result in data
            corruption or loss, since this chain may attempt commands like
            `DROP TABLE` or `INSERT` if appropriately prompted. The best way
            to guard against such negative outcomes is to (as appropriate)
            limit the permissions granted to the credentials used with this
            chain.
        """
        return cls(
            llm=llm,
            chain=LLMChain(
                llm=llm,
                prompt=PromptTemplate(
                    input_variables=["question", "query_result"],
                    template=(
                        "Summarize this answer '{query_result}' to this "
                        "question '{question}'? "
                    ),
                ),
            ),
            narrative_chain=NarrativeChain.from_univariate_prompt(llm=llm),
            causal_chain=CausalChain.from_univariate_prompt(llm=llm),
            intervention_chain=InterventionChain.from_univariate_prompt(llm=llm),
            query_chain=QueryChain.from_univariate_prompt(llm=llm),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Decompose the narrative, build the story, and answer its query.

        Returns:
            A dict holding the ``StoryModel`` under ``Constant.chain_data.value``
            and the numeric query result (float) under ``self.output_key``.
            Any extra ``kwargs`` are merged into the dict.

        Raises:
            ValueError: if the query's result table is empty, meaning the
                query and the outcome table are incoherent.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        # instantiate any component chain that was not supplied
        if self.narrative_chain is None:
            self.narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.llm)
        if self.causal_chain is None:
            self.causal_chain = CausalChain.from_univariate_prompt(llm=self.llm)
        if self.intervention_chain is None:
            self.intervention_chain = InterventionChain.from_univariate_prompt(
                llm=self.llm
            )
        if self.query_chain is None:
            self.query_chain = QueryChain.from_univariate_prompt(llm=self.llm)

        # decompose narrative into three causal story elements; callbacks are
        # propagated so each sub-chain is traced as a child of this run
        narrative = self.narrative_chain(
            inputs[self.input_key], callbacks=_run_manager.get_child()
        )[Constant.chain_data.value]
        story = StoryModel(
            causal_operations=self.causal_chain(
                narrative.story_plot, callbacks=_run_manager.get_child()
            )[Constant.chain_data.value],
            intervention=self.intervention_chain(
                narrative.story_hypothetical, callbacks=_run_manager.get_child()
            )[Constant.chain_data.value],
            query=self.query_chain(
                narrative.story_outcome_question, callbacks=_run_manager.get_child()
            )[Constant.chain_data.value],
        )
        self._story = story

        _run_manager.on_text(
            "story outcome data" + "\n" + story._outcome_table.to_string(),
            color="green",
            end="\n\n",
            verbose=self.verbose,
        )
        _run_manager.on_text(
            "query data" + "\n" + json.dumps(story.query.dict(), indent=4),
            color="blue",
            end="\n\n",
            verbose=self.verbose,
        )

        if story.query._result_table.empty:
            # prevent piping bad data into subsequent chains
            raise ValueError(
                (
                    "unanswerable, query and outcome are incoherent\n"
                    "\n"
                    "outcome:\n"
                    f"{story._outcome_table}\n"
                    "query:\n"
                    f"{story.query.dict()}"
                )
            )
        # last column of the first row of the query's result table
        query_result = float(story.query._result_table.values[0][-1])

        # TODO: when composable chains need it, use self.chain to turn
        # query_result into a human-readable report for story.query.question.

        output = {
            Constant.chain_data.value: story,
            self.output_key: query_result,
            **kwargs,
        }
        return output

    def draw(self, **kwargs: Any) -> None:
        """
        CPAL chain can draw its resulting DAG.

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> cpal_chain.draw(path="graph.svg")
            >>> SVG('graph.svg')

        Raises:
            ValueError: if the chain has not been run yet, so there is no story.
        """
        story = self._story
        if story is None:
            raise ValueError("No story to draw; run the chain before calling draw().")
        story._networkx_wrapper.draw_graphviz(**kwargs)