Source code for langchain_experimental.agents.agent_toolkits.spark.base

"""Agent for working with pandas objects."""

from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.language_models import BaseLLM

from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain_experimental.tools.python.tool import PythonAstREPLTool


def _validate_spark_df(df: Any) -> bool:
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame

        return isinstance(df, SparkLocalDataFrame)
    except ImportError:
        return False


def _validate_spark_connect_df(df: Any) -> bool:
    try:
        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

        return isinstance(df, SparkConnectDataFrame)
    except ImportError:
        return False


def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    allow_dangerous_code: bool = False,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and dataframe.

    Security Notice:
        This agent relies on access to a python repl tool which can execute
        arbitrary code. This can be dangerous and requires a specially
        sandboxed environment to be safely used. Failure to run this code in
        a properly sandboxed environment can lead to arbitrary code execution
        vulnerabilities, which can lead to data breaches, data loss, or other
        security incidents.

        Do not use this code with untrusted inputs, with elevated permissions,
        or without consulting your security team about proper sandboxing!

        You must opt in to use this functionality by setting
        allow_dangerous_code=True.

    Args:
        llm: Language model handed to the ``LLMChain`` that drives the agent.
        df: Spark DataFrame to work on. It must be an instance of
            ``pyspark.sql.DataFrame`` or of the Spark Connect DataFrame class.
            It is exposed to the Python REPL tool under the name ``df``, and
            ``str(df.first())`` is embedded in the prompt.
        callback_manager: Optional callback manager forwarded to the LLM
            chain, the agent, and the executor.
        prefix: Prompt prefix passed to ``ZeroShotAgent.create_prompt``.
        suffix: Prompt suffix passed to ``ZeroShotAgent.create_prompt``.
        input_variables: Prompt input variables. Defaults to
            ``["df", "input", "agent_scratchpad"]`` when ``None``.
        verbose: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        return_intermediate_steps: Forwarded to
            ``AgentExecutor.from_agent_and_tools``.
        max_iterations: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        max_execution_time: Forwarded to
            ``AgentExecutor.from_agent_and_tools``.
        early_stopping_method: Forwarded to
            ``AgentExecutor.from_agent_and_tools``.
        agent_executor_kwargs: Extra keyword arguments unpacked into
            ``AgentExecutor.from_agent_and_tools``.
        allow_dangerous_code: bool, default False
            This agent relies on access to a python repl tool which can
            execute arbitrary code. This can be dangerous and requires a
            specially sandboxed environment to be safely used. Failure to
            properly sandbox this class can lead to arbitrary code execution
            vulnerabilities, which can lead to data breaches, data loss, or
            other security incidents. You must opt in to use this
            functionality by setting allow_dangerous_code=True.
        **kwargs: Extra keyword arguments unpacked into ``ZeroShotAgent``.

    Returns:
        An ``AgentExecutor`` wrapping a ``ZeroShotAgent`` whose only tool is a
        ``PythonAstREPLTool`` with ``df`` in its locals.

    Raises:
        ValueError: If ``allow_dangerous_code`` is not truthy.
        ImportError: If ``df`` is not recognised as a Spark DataFrame, which
            is also the outcome when pyspark cannot be imported.
    """
    if not allow_dangerous_code:
        raise ValueError(
            "This agent relies on access to a python repl tool which can execute "
            "arbitrary code. This can be dangerous and requires a specially sandboxed "
            "environment to be safely used. Please read the security notice in the "
            "doc-string of this function. You must opt-in to use this functionality "
            "by setting allow_dangerous_code=True. "
            "For general security guidelines, please see: "
            "https://python.langchain.com/v0.2/docs/security/"
        )

    if not (_validate_spark_df(df) or _validate_spark_connect_df(df)):
        # Both validators also return False when pyspark itself is missing, so
        # the message has to cover the "wrong object" and "not installed" cases.
        # ImportError is kept so existing callers that catch it still work.
        raise ImportError(
            f"`df` is not a Spark DataFrame (got {type(df).__name__}), or Spark "
            "is not installed. Pass a pyspark DataFrame and, if pyspark is "
            "missing, run `pip install pyspark`."
        )

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]

    # The REPL tool is the agent's only tool; the dataframe is reachable as `df`.
    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )
    # Pre-fill the `df` prompt variable with the first row so the model sees
    # the shape of the data.
    partial_prompt = prompt.partial(df=str(df.first()))
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(  # type: ignore[call-arg]
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )