import itertools
import logging
import sys
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader

if TYPE_CHECKING:
    from pyspark.sql import SparkSession

logger = logging.getLogger(__name__)


class PySparkDataFrameLoader(BaseLoader):
    """Load `PySpark` DataFrames."""

    def __init__(
        self,
        spark_session: Optional["SparkSession"] = None,
        df: Optional[Any] = None,
        page_content_column: str = "text",
        fraction_of_memory: float = 0.1,
    ):
        """Initialize with a Spark DataFrame object.

        Args:
            spark_session: The SparkSession object.
            df: The Spark DataFrame object.
            page_content_column: The name of the column containing the page
                content. Defaults to "text".
            fraction_of_memory: The fraction of memory to use. Defaults to 0.1.
        """
        try:
            from pyspark.sql import DataFrame, SparkSession
        except ImportError:
            raise ImportError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            )

        self.spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )

        if not isinstance(df, DataFrame):
            raise ValueError(
                f"Expected df to be a PySpark DataFrame, got {type(df)}"
            )
        self.df = df
        self.page_content_column = page_content_column
        self.fraction_of_memory = fraction_of_memory
        self.num_rows, self.max_num_rows = self.get_num_rows()
        self.rdd_df = self.df.rdd.map(list)
        self.column_names = self.df.columns
    def get_num_rows(self) -> Tuple[int, int]:
        """Gets the number of "feasible" rows for the DataFrame."""
        try:
            import psutil
        except ImportError as e:
            raise ImportError(
                "psutil not installed. Please install it with `pip install psutil`."
            ) from e
        row = self.df.limit(1).collect()[0]
        estimated_row_size = sys.getsizeof(row)
        mem_info = psutil.virtual_memory()
        available_memory = mem_info.available
        max_num_rows = int(
            (available_memory / estimated_row_size) * self.fraction_of_memory
        )
        return min(max_num_rows, self.df.count()), max_num_rows
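    # Worked illustration of the row budget computed above (hypothetical
    # numbers): with ~8 GB of available memory and a ~1 KB estimated row,
    # fraction_of_memory=0.1 gives
    # max_num_rows = int((8e9 / 1e3) * 0.1) = 800_000 rows.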
    def lazy_load(self) -> Iterator[Document]:
        """A lazy loader for document content."""
        for row in self.rdd_df.toLocalIterator():
            metadata = {self.column_names[i]: row[i] for i in range(len(row))}
            text = metadata[self.page_content_column]
            metadata.pop(self.page_content_column)
            yield Document(page_content=text, metadata=metadata)
    def load(self) -> List[Document]:
        """Load from the dataframe."""
        if self.df.count() > self.max_num_rows:
            logger.warning(
                f"The number of DataFrame rows is {self.df.count()}, "
                f"but we will only include the amount "
                f"of rows that can reasonably fit in memory: {self.num_rows}."
            )
        lazy_load_iterator = self.lazy_load()
        return list(itertools.islice(lazy_load_iterator, self.num_rows))
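
# --- Usage sketch (not part of the loader itself) ---
# A minimal, hypothetical example of driving the loader end to end; the
# DataFrame contents and column names below are illustrative assumptions,
# chosen so the default page_content_column="text" applies.
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("loader-demo").getOrCreate()
    df = spark.createDataFrame(
        [("Hello, Spark!", "greeting"), ("Goodbye, Spark!", "farewell")],
        schema=["text", "category"],
    )
    loader = PySparkDataFrameLoader(spark_session=spark, df=df)
    # Each Document carries the "text" column as page_content and the
    # remaining columns ("category" here) as metadata.
    for doc in loader.load():
        print(doc.page_content, doc.metadata)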