"""Chains that turn a natural-language question into SQL, run it against a
SQL database, and optionally ask the LLM to phrase an answer from the result.
"""

from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.schema import BasePromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field, model_validator

# Output key under which the chains return their recorded intermediate steps
# when ``return_intermediate_steps`` is enabled.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
# Marker appended to the question before SQL generation; if the LLM echoes it,
# everything up to and including it is stripped from the generated text.
SQL_QUERY = "SQLQuery:"
# Marker introducing the query result in the prompt; any generated text from
# this marker onward is discarded before the SQL is executed.
SQL_RESULT = "SQLResult:"
class SQLDatabaseChain(Chain):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain_community.llms import OpenAI
            from langchain_community.utilities import SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL Database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt to use to translate natural language to SQL."""
    top_k: int = 5
    """Number of results to return from the query"""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_sql: bool = False
    """Will return sql-command directly without executing it"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker"""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def raise_deprecation(cls, values: Dict) -> Any:
        """Warn about the deprecated ``llm`` argument and derive ``llm_chain`` from it.

        When ``llm`` is supplied without ``llm_chain``, an ``LLMChain`` is built
        from ``llm`` and either the given ``prompt`` or the default prompt for
        the database's dialect (falling back to the generic ``PROMPT``).
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    @staticmethod
    def _extract_sql(text: str) -> str:
        """Return the SQL statement contained in raw LLM output.

        Drops everything up to an echoed ``SQLQuery:`` marker and everything
        from a ``SQLResult:`` marker onward, then trims whitespace.
        """
        if SQL_QUERY in text:
            text = text.split(SQL_QUERY)[1].strip()
        if SQL_RESULT in text:
            text = text.split(SQL_RESULT)[0].strip()
        return text

    def _build_outputs(self, answer: Any, intermediate_steps: List) -> Dict[str, Any]:
        """Package ``answer`` (plus the steps, if requested) under ``output_keys``."""
        chain_result: Dict[str, Any] = {self.output_key: answer}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
        return chain_result

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL for the question, optionally check and run it, then answer.

        Args:
            inputs: Must contain ``self.input_key``; may contain
                ``"table_names_to_use"`` to restrict the schema shown to the LLM,
                plus any memory variables.
            run_manager: Callback manager for this run.

        Returns:
            A dict keyed by ``output_keys``. The value under ``self.output_key``
            is the generated SQL (``return_sql``), the raw query result
            (``return_direct``), or the LLM's final answer.

        Raises:
            Any exception from the LLM or database, re-raised unchanged except
            for an added ``intermediate_steps`` attribute with the steps so far.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                # output_keys advertises the intermediate-steps key whenever
                # return_intermediate_steps is set, so it must be present here too.
                intermediate_steps.append(sql_cmd)  # output: sql generation
                return self._build_outputs(sql_cmd, intermediate_steps)
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                # Record exactly the statement that is executed.
                sql_cmd = self._extract_sql(sql_cmd)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    # Give the checker the bare SQL, not any echoed prompt markers.
                    "query": self._extract_sql(sql_cmd),
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                checked_sql_command = self._extract_sql(checked_sql_command)
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            return self._build_outputs(final_result, intermediate_steps)
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            # Bare raise re-raises the active exception without an extra frame.
            raise

    @property
    def _chain_type(self) -> str:
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        """Create a SQLDatabaseChain from an LLM and a database connection.

        If ``prompt`` is omitted, the dialect-specific prompt from ``SQL_PROMPTS``
        is used, falling back to the generic ``PROMPT``.

        *Security note*: Make sure that the database connection uses credentials
            that are narrowly-scoped to only include the permissions this chain needs.
            Failure to do so may result in data corruption or loss, since this chain may
            attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
            The best way to guard against such negative outcomes is to (as appropriate)
            limit the permissions granted to the credentials used with this chain.
            This issue shows an example negative outcome if these steps are not taken:
            https://github.com/langchain-ai/langchain/issues/5923
        """
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)
class SQLDatabaseSequentialChain(Chain):
    """Chain for querying SQL database that is a sequential chain.

    The chain is as follows:
    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains.

        ``kwargs`` are forwarded both to the inner ``SQLDatabaseChain`` and to
        this chain's constructor.
        """
        sql_chain = SQLDatabaseChain.from_llm(llm, db, prompt=query_prompt, **kwargs)
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    @staticmethod
    def _resolve_table_names(
        candidates: Any, usable_table_names: List[str]
    ) -> List[str]:
        """Map LLM-proposed table names onto the database's own spelling.

        An exact match wins; otherwise a case-insensitive match is used. Unknown
        names are dropped and duplicates removed, keeping first-seen order.
        A bare string (decider prompt without a list parser) is split on commas.
        """
        if isinstance(candidates, str):
            candidates = candidates.split(",")
        exact = set(usable_table_names)
        by_lower: Dict[str, str] = {}
        for name in usable_table_names:
            by_lower.setdefault(name.lower(), name)
        resolved: List[str] = []
        for candidate in candidates:
            candidate = candidate.strip()
            canonical = (
                candidate if candidate in exact else by_lower.get(candidate.lower())
            )
            if canonical is not None and canonical not in resolved:
                resolved.append(canonical)
        return resolved

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Pick relevant tables with the decider chain, then run the SQL chain.

        The tables chosen by the decider are resolved to the exact names the
        database reports. This matters because ``get_table_info`` rejects names
        that are not spelled exactly as listed.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Materialise once: the result is iterated more than once below.
        _table_names = list(self.sql_chain.database.get_usable_table_names())
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        table_names_from_chain = self.decider_chain.predict_and_parse(
            callbacks=_run_manager.get_child(), **llm_inputs
        )
        table_names_to_use = self._resolve_table_names(
            table_names_from_chain, _table_names
        )
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )

    @property
    def _chain_type(self) -> str:
        return "sql_database_sequential_chain"