"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.language_models import BaseLLM
from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def _validate_spark_df(df: Any) -> bool:
try:
from pyspark.sql import DataFrame as SparkLocalDataFrame
return isinstance(df, SparkLocalDataFrame)
except ImportError:
return False
def _validate_spark_connect_df(df: Any) -> bool:
try:
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
return isinstance(df, SparkConnectDataFrame)
except ImportError:
return False
def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    allow_dangerous_code: bool = False,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and dataframe.

    Security Notice:
        This agent relies on access to a python repl tool which can execute
        arbitrary code. This can be dangerous and requires a specially sandboxed
        environment to be safely used. Failure to run this code in a properly
        sandboxed environment can lead to arbitrary code execution vulnerabilities,
        which can lead to data breaches, data loss, or other security incidents.

        Do not use this code with untrusted inputs, with elevated permissions,
        or without consulting your security team about proper sandboxing!

        You must opt in to use this functionality by setting allow_dangerous_code=True.

    Args:
        llm: Language model used by the agent's LLMChain.
        df: A ``pyspark.sql.DataFrame`` or a Spark Connect DataFrame. It is
            exposed to the Python REPL tool under the local name ``df``, and
            ``str(df.first())`` is bound to the ``df`` prompt variable.
        callback_manager: Passed to the LLMChain, the ZeroShotAgent and the
            AgentExecutor.
        prefix: Prompt prefix passed to ``ZeroShotAgent.create_prompt``.
        suffix: Prompt suffix passed to ``ZeroShotAgent.create_prompt``.
        input_variables: Prompt input variables. Defaults to
            ``["df", "input", "agent_scratchpad"]``.
        verbose: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        return_intermediate_steps: Forwarded to the AgentExecutor.
        max_iterations: Forwarded to the AgentExecutor.
        max_execution_time: Forwarded to the AgentExecutor.
        early_stopping_method: Forwarded to the AgentExecutor.
        agent_executor_kwargs: Extra keyword arguments forwarded to
            ``AgentExecutor.from_agent_and_tools``.
        allow_dangerous_code: bool, default False
            This agent relies on access to a python repl tool which can execute
            arbitrary code. This can be dangerous and requires a specially sandboxed
            environment to be safely used.
            Failure to properly sandbox this class can lead to arbitrary code execution
            vulnerabilities, which can lead to data breaches, data loss, or
            other security incidents.
            You must opt in to use this functionality by setting
            allow_dangerous_code=True.
        **kwargs: Extra keyword arguments forwarded to ``ZeroShotAgent``.

    Returns:
        An AgentExecutor wrapping a ZeroShotAgent that has a single
        PythonAstREPLTool.

    Raises:
        ValueError: If ``allow_dangerous_code`` is not True.
        ImportError: If pyspark is not installed, or if ``df`` is not a
            supported Spark DataFrame.
    """
    if not allow_dangerous_code:
        raise ValueError(
            "This agent relies on access to a python repl tool which can execute "
            "arbitrary code. This can be dangerous and requires a specially sandboxed "
            "environment to be safely used. Please read the security notice in the "
            "doc-string of this function. You must opt-in to use this functionality "
            "by setting allow_dangerous_code=True. "
            "For general security guidelines, please see: "
            "https://python.langchain.com/docs/security/"
        )

    if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
        # Both validators also return False when pyspark cannot be imported,
        # so probe the package explicitly to report the real cause.
        try:
            import pyspark  # noqa: F401
        except ImportError as e:
            raise ImportError("Spark is not installed. run `pip install pyspark`.") from e
        # pyspark is importable, so the object itself is the problem.
        # ImportError (rather than TypeError) is kept for backward compatibility
        # with callers that already catch ImportError from this function.
        raise ImportError(
            "Expected a Spark DataFrame (pyspark.sql.DataFrame or a Spark Connect "
            f"DataFrame), got {type(df).__name__}."
        )

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]

    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )
    # Bind the first row of the dataframe into the prompt's `df` variable.
    partial_prompt = prompt.partial(df=str(df.first()))

    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(  # type: ignore[call-arg]
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )