from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, List, Optional

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row, SparkSession


class SparkSQL:
    """SparkSQL is a utility class for interacting with Spark SQL."""

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
                If not provided, one will be created.
            catalog: The catalog to use.
                If not provided, the default catalog will be used.
            schema: The schema to use.
                If not provided, the default schema will be used.
            ignore_tables: A list of tables to ignore.
                If not provided, all tables will be used.
            include_tables: A list of tables to include.
                If not provided, all tables will be used.
            sample_rows_in_table_info: The number of sample rows to include
                in the table info. Defaults to 3.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            )

        self._spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())
        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info
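
    # Usage sketch (illustrative, not part of the class): build the utility on top
    # of the active session and restrict it to two hypothetical tables. The schema
    # and table names below are assumptions, not part of this module.
    #
    #     spark_sql = SparkSQL(
    #         schema="langchain_example",           # assumed schema name
    #         include_tables=["users", "events"],   # raises ValueError if missing
    #         sample_rows_in_table_info=3,
    #     )
    #     print(spark_sql.get_usable_table_names())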

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create a remote Spark session via Spark Connect.

        For example: SparkSQL.from_uri("sc://localhost:15002")
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            )

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)
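
    # Sketch: connect to a Spark Connect server (endpoint from the docstring above)
    # and forward keyword arguments to __init__. The schema name is hypothetical.
    #
    #     spark_sql = SparkSQL.from_uri(
    #         "sc://localhost:15002",
    #         schema="langchain_example",  # assumed; passed through to __init__
    #     )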

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        if self._include_tables:
            return self._include_tables
        # Sorting the result can help the LLM understand it.
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return list(map(lambda row: row.tableName, rows))

    def _get_create_table_stmt(self, table: str) -> str:
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        return statement[:using_clause_index] + ";"
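
    # For illustration (hypothetical table): if `SHOW CREATE TABLE users` returns
    #
    #     CREATE TABLE spark_catalog.default.users (id INT, name STRING)
    #     USING parquet
    #     TBLPROPERTIES (...)
    #
    # then _get_create_table_stmt keeps only the portion before USING and appends
    # a semicolon:
    #
    #     CREATE TABLE spark_catalog.default.users (id INT, name STRING);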

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        final_str = "\n\n".join(tables)
        return final_str
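
    # Per table, the returned string (sketch, with hypothetical data) looks like:
    #
    #     CREATE TABLE spark_catalog.default.users (id INT, name STRING);
    #
    #     /*
    #     3 rows from users table:
    #     id    name
    #     1     Alice
    #     2     Bob
    #     3     Carol
    #     */
    #
    # The sample rows are tab-separated; spaces are shown here for readability.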

    def _get_sample_spark_rows(self, table: str) -> str:
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(list(map(lambda f: f.name, df.schema.fields)))
        try:
            sample_rows = self._get_dataframe_results(df)
            # Save the sample rows in string format.
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
        except Exception:
            sample_rows_str = ""
        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        return tuple(map(str, row.asDict().values()))

    def _get_dataframe_results(self, df: DataFrame) -> list:
        return list(map(self._convert_row_as_tuple, df.collect()))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info` is set, the specified number of sample rows
        will be appended to each table description. This can increase performance
        as demonstrated in the paper.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message.
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.
        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message.
            return f"Error: {e}"
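
    # Sketch: because run_no_throw catches exceptions and returns them as strings,
    # it is safe to feed LLM-generated SQL through it. `run` is defined elsewhere
    # in this class; the queries below are hypothetical.
    #
    #     spark_sql.run_no_throw("SELECT * FROM users LIMIT 1")   # result string
    #     spark_sql.run_no_throw("SELECT * FROM no_such_table")   # "Error: ..."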