Source code for langchain_experimental.pal_chain.base
"""Implements Program-Aided Language Models.

This module implements the Program-Aided Language Models (PAL) for generating code
solutions. PAL is a technique described in the paper "Program-Aided Language Models"
(https://arxiv.org/pdf/2211.10435.pdf).
"""

from __future__ import annotations

import ast
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT
from langchain_experimental.utilities import PythonREPL

# Callable names rejected by PALChain.validate_code when command execution is
# not allowed. A call is rejected when the called name, or the attribute being
# called (e.g. ``os.system``), appears in this list.
COMMAND_EXECUTION_FUNCTIONS = [
    "system",
    "exec",
    "execfile",
    "eval",
    "__import__",
    "compile",
]

# Attribute names rejected by PALChain.validate_code when command execution is
# not allowed. Any attribute access (``obj.<name>``) using one of these names is
# rejected, whether or not it is called.
COMMAND_EXECUTION_ATTRIBUTES = [
    "__import__",
    "__subclasses__",
    "__builtins__",
    "__globals__",
    "__getattribute__",
    "__code__",
    "__bases__",
    "__mro__",
    "__base__",
]
class PALValidation:
    """Settings describing the checks that PAL-generated code must pass.

    An instance records which solution expression (if any) the generated code
    is required to define, and whether import statements and known
    command-execution functions are permitted.
    """

    # AST node classes accepted as ``solution_expression_type``.
    SOLUTION_EXPRESSION_TYPE_FUNCTION = ast.FunctionDef
    SOLUTION_EXPRESSION_TYPE_VARIABLE = ast.Name

    def __init__(
        self,
        solution_expression_name: Optional[str] = None,
        solution_expression_type: Optional[type] = None,
        allow_imports: bool = False,
        allow_command_exec: bool = False,
    ):
        """Initialize a PALValidation instance.

        Args:
            solution_expression_name: Name of the expected solution expression.
                If passed, ``solution_expression_type`` must be passed as well.
            solution_expression_type: AST node class of the expected solution
                expression. If passed, ``solution_expression_name`` must be
                passed as well. Must be one of
                ``PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION`` or
                ``PALValidation.SOLUTION_EXPRESSION_TYPE_VARIABLE``.
            allow_imports: Allow import statements.
            allow_command_exec: Allow using known command execution functions.

        Raises:
            ValueError: If ``solution_expression_name`` is not a ``str``, or
                ``solution_expression_type`` is not one of the two accepted
                AST node classes.
            TypeError: If only one of ``solution_expression_name`` and
                ``solution_expression_type`` is given.
        """
        self.solution_expression_name = solution_expression_name
        self.solution_expression_type = solution_expression_type

        name_given = solution_expression_name is not None
        type_given = solution_expression_type is not None

        # Value checks come first, so a malformed argument is reported as a
        # ValueError even when its companion argument is also missing.
        if name_given and not isinstance(solution_expression_name, str):
            raise ValueError(
                f"Expected solution_expression_name to be str, "
                f"instead found {type(solution_expression_name)}"
            )

        if type_given:
            accepted_types = (
                self.SOLUTION_EXPRESSION_TYPE_FUNCTION,
                self.SOLUTION_EXPRESSION_TYPE_VARIABLE,
            )
            # Identity comparison: only the exact AST classes are accepted.
            if not any(solution_expression_type is kind for kind in accepted_types):
                raise ValueError(
                    f"Expected solution_expression_type to be one of "
                    f"({accepted_types[0]},"
                    f"{accepted_types[1]}),"
                    f"instead found {solution_expression_type}"
                )

        # The name and the type only make sense as a pair.
        if name_given and not type_given:
            raise TypeError(
                "solution_expression_name "
                "requires solution_expression_type to be passed as well"
            )
        if type_given and not name_given:
            raise TypeError(
                "solution_expression_type "
                "requires solution_expression_name to be passed as well"
            )

        self.allow_imports = allow_imports
        self.allow_command_exec = allow_command_exec
class PALChain(Chain):
    """Chain that implements Program-Aided Language Models (PAL).

    This class implements the Program-Aided Language Models (PAL) for generating
    code solutions. PAL is a technique described in the paper "Program-Aided
    Language Models" (https://arxiv.org/pdf/2211.10435.pdf).

    *Security note*: This class implements an AI technique that generates and
        evaluates Python code, which can be dangerous and requires a specially
        sandboxed environment to be safely used. While this class implements
        some basic guardrails by limiting available locals/globals and by
        parsing and inspecting the generated Python AST using `PALValidation`,
        those guardrails will not deter sophisticated attackers and are not a
        replacement for a proper sandbox. Do not use this class on untrusted
        inputs, with elevated permissions, or without consulting your security
        team about proper sandboxing!
    """

    llm_chain: LLMChain
    stop: str = "\n\n"
    """Stop token to use when generating code."""
    get_answer_expr: str = "print(solution())"
    """Expression to use to get the answer from the generated code."""
    python_globals: Optional[Dict[str, Any]] = None
    """Python globals to use when executing the generated code."""
    python_locals: Optional[Dict[str, Any]] = None
    """Python locals to use when executing the generated code."""
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False
    """Whether to return intermediate steps in the generated code."""
    code_validations: PALValidation = Field(default_factory=PALValidation)
    """Validations to perform on the generated code."""
    timeout: Optional[int] = 10
    """Timeout in seconds for the generated code to execute."""
    allow_dangerous_code: bool = False
    """This chain relies on the execution of generated code, which can be dangerous.

    This class implements an AI technique that generates and evaluates
    Python code, which can be dangerous and requires a specially sandboxed
    environment to be safely used.

    While this class implements some basic guardrails by limiting available
    locals/globals and by parsing and inspecting the generated Python AST
    using `PALValidation`, those guardrails will not deter sophisticated
    attackers and are not a replacement for a proper sandbox. Do not use this
    class on untrusted inputs, with elevated permissions, or without
    consulting your security team about proper sandboxing!

    Failure to properly sandbox this class can lead to arbitrary code execution
    vulnerabilities, which can lead to data breaches, data loss, or other
    security incidents.
    """

    @model_validator(mode="after")
    def post_init(self) -> Self:
        """Refuse to build the chain unless the caller explicitly opted in.

        Returns:
            The validated instance.

        Raises:
            ValueError: If ``allow_dangerous_code`` is not set to ``True``.
        """
        if not self.allow_dangerous_code:
            raise ValueError(
                "This chain relies on the execution of generated code, "
                "which can be dangerous. "
                "Please read the security notice for this class, and only "
                "use it if you understand the security implications. "
                "If you want to proceed, you will need to opt-in, by setting "
                "`allow_dangerous_code` to `True`."
            )
        return self

    # Unknown fields are rejected; arbitrary (non-pydantic) field types are
    # allowed, which `code_validations` (a plain `PALValidation`) relies on.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys: the input variables of the LLM chain's prompt.

        :meta private:
        """
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        Always contains ``output_key``; ``"intermediate_steps"`` is added when
        ``return_intermediate_steps`` is enabled.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate code with the LLM, validate it, execute it, return its output.

        Args:
            inputs: Values for the prompt's input variables; forwarded as
                keyword arguments to ``llm_chain.predict``.
            run_manager: Callback manager for this run. A no-op manager is used
                when none is given.

        Returns:
            A dict mapping ``output_key`` to the stripped output of running the
            generated code followed by ``get_answer_expr``. When
            ``return_intermediate_steps`` is set, the generated code is also
            returned under ``"intermediate_steps"``.

        Raises:
            ValueError: If the generated code fails ``validate_code``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        code = self.llm_chain.predict(
            stop=[self.stop], callbacks=_run_manager.get_child(), **inputs
        )
        _run_manager.on_text(code, color="green", end="\n", verbose=self.verbose)

        # Validation happens before anything is handed to the REPL.
        PALChain.validate_code(code, self.code_validations)

        # TODO: look into why mypy thinks PythonREPL's type here is `Any`
        #       and therefore not callable
        repl = PythonREPL(
            _globals=self.python_globals,
            _locals=self.python_locals,
        )  # type: ignore[misc]
        # The answer expression is appended so that running the code emits the answer.
        res = repl.run(code + f"\n{self.get_answer_expr}", timeout=self.timeout)
        output = {self.output_key: res.strip()}
        if self.return_intermediate_steps:
            output["intermediate_steps"] = code
        return output

    @classmethod
    def validate_code(cls, code: str, code_validations: PALValidation) -> None:
        """Check generated code against the given validations.

        The code is parsed with ``ast`` and then inspected:

        * if a solution expression is configured, a matching top-level function
          definition or top-level assignment target must be present;
        * unless ``allow_imports`` is set, no ``import`` / ``from ... import``
          statement may appear anywhere in the code;
        * unless ``allow_command_exec`` is set, no attribute listed in
          ``COMMAND_EXECUTION_ATTRIBUTES`` may be accessed and no function
          listed in ``COMMAND_EXECUTION_FUNCTIONS`` may be called.

        Args:
            code: The generated Python source code.
            code_validations: The validations to apply.

        Raises:
            ValueError: If the code cannot be parsed, or if any of the checks
                above fails. Parser errors are attached as ``__cause__``.
        """
        try:
            code_tree = ast.parse(code)
        except (SyntaxError, UnicodeDecodeError) as err:
            raise ValueError(
                f"Generated code is not valid python code: {code}"
            ) from err
        except TypeError as err:
            raise ValueError(
                f"Generated code is expected to be a string, "
                f"instead found {type(code)}"
            ) from err
        except (OverflowError, RecursionError, MemoryError) as err:
            # Over-long or over-nested source can exhaust the parser in more
            # than one way; report all of them uniformly as a ValueError.
            raise ValueError(
                f"Generated code too long / complex to be parsed by ast: {code}"
            ) from err

        expected_name = code_validations.solution_expression_name
        expected_type = code_validations.solution_expression_type

        # Skip the solution-expression check if no name was configured.
        found_solution_expr = expected_name is None

        has_imports = False
        for node in ast.iter_child_nodes(code_tree):
            if expected_name is not None and expected_type is not None:
                # Check root nodes (like func def)
                if (
                    isinstance(node, expected_type)
                    and hasattr(node, "name")
                    and node.name == expected_name
                ):
                    found_solution_expr = True
                # Check assigned nodes (like answer variable)
                if isinstance(node, ast.Assign):
                    for target_node in node.targets:
                        if (
                            isinstance(target_node, expected_type)
                            and hasattr(target_node, "id")
                            and target_node.id == expected_name
                        ):
                            found_solution_expr = True
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                has_imports = True

        if not found_solution_expr:
            raise ValueError(
                f"Generated code is missing the solution expression: "
                f"{code_validations.solution_expression_name} of type: "
                f"{code_validations.solution_expression_type}"
            )

        if not code_validations.allow_imports and has_imports:
            raise ValueError(f"Generated code has disallowed imports: {code}")

        if code_validations.allow_command_exec and code_validations.allow_imports:
            # Nothing left to forbid; no need to walk the tree.
            return

        # Walk the whole tree so nested statements and expressions are covered too.
        for node in ast.walk(code_tree):
            if (
                not code_validations.allow_command_exec
                and isinstance(node, ast.Attribute)
                and node.attr in COMMAND_EXECUTION_ATTRIBUTES
            ):
                raise ValueError(
                    f"Found illegal command execution function "
                    f"{node.attr} in code {code}"
                )

            if (not code_validations.allow_command_exec) and isinstance(
                node, ast.Call
            ):
                # Plain call by name, e.g. ``eval(...)``.
                if (
                    hasattr(node.func, "id")
                    and node.func.id in COMMAND_EXECUTION_FUNCTIONS
                ):
                    raise ValueError(
                        f"Found illegal command execution function "
                        f"{node.func.id} in code {code}"
                    )

                # Call through an attribute, e.g. ``os.system(...)``.
                if (
                    isinstance(node.func, ast.Attribute)
                    and node.func.attr in COMMAND_EXECUTION_FUNCTIONS
                ):
                    raise ValueError(
                        f"Found illegal command execution function "
                        f"{node.func.attr} in code {code}"
                    )

            if (not code_validations.allow_imports) and isinstance(
                node, (ast.Import, ast.ImportFrom)
            ):
                raise ValueError(f"Generated code has disallowed imports: {code}")

    @classmethod
    def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
        """Load PAL from math prompt.

        The generated code is required to define a function named ``solution``;
        the answer is obtained with ``print(solution())``.

        Args:
            llm (BaseLanguageModel): The language model to use for generating code.
            **kwargs: Additional keyword arguments forwarded to the constructor.
                ``allow_dangerous_code=True`` must be among them, otherwise
                construction raises ``ValueError``. They must not repeat
                ``llm_chain``, ``stop``, ``get_answer_expr`` or
                ``code_validations``, which are set here.

        Returns:
            PALChain: An instance of PALChain.
        """
        llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT)
        code_validations = PALValidation(
            solution_expression_name="solution",
            solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION,
        )

        return cls(
            llm_chain=llm_chain,
            stop="\n\n",
            get_answer_expr="print(solution())",
            code_validations=code_validations,
            **kwargs,
        )

    @classmethod
    def from_colored_object_prompt(
        cls, llm: BaseLanguageModel, **kwargs: Any
    ) -> PALChain:
        """Load PAL from colored object prompt.

        The generated code is required to assign a top-level variable named
        ``answer``; the answer is obtained with ``print(answer)``.

        Args:
            llm (BaseLanguageModel): The language model to use for generating code.
            **kwargs: Additional keyword arguments forwarded to the constructor.
                ``allow_dangerous_code=True`` must be among them, otherwise
                construction raises ``ValueError``. They must not repeat
                ``llm_chain``, ``stop``, ``get_answer_expr`` or
                ``code_validations``, which are set here.

        Returns:
            PALChain: An instance of PALChain.
        """
        llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT)
        code_validations = PALValidation(
            solution_expression_name="answer",
            solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_VARIABLE,
        )
        return cls(
            llm_chain=llm_chain,
            stop="\n\n\n",
            get_answer_expr="print(answer)",
            code_validations=code_validations,
            **kwargs,
        )