Source code for langchain_community.document_loaders.pyspark_dataframe
import itertools
import logging
import sys
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
# Module-level logger. Named after the module (not the file path) so that it
# takes part in the package's dotted logger hierarchy and can be configured
# through a parent logger.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed for type annotations only; the runtime import of pyspark happens
    # lazily inside the loader so that pyspark stays an optional dependency.
    from pyspark.sql import SparkSession
class PySparkDataFrameLoader(BaseLoader):
    """Load `PySpark` DataFrames.

    Every DataFrame row is turned into one ``Document``: the value of
    ``page_content_column`` becomes the page content and all remaining
    columns become the document metadata.
    """

    def __init__(
        self,
        spark_session: Optional["SparkSession"] = None,
        df: Optional[Any] = None,
        page_content_column: str = "text",
        fraction_of_memory: float = 0.1,
    ):
        """Initialize with a Spark DataFrame object.

        Note that construction already touches the data: ``get_num_rows`` is
        called here, which collects one row and counts the DataFrame.

        Args:
            spark_session: The SparkSession object. When omitted (or falsy),
                ``SparkSession.builder.getOrCreate()`` is used.
            df: The Spark DataFrame object.
            page_content_column: The name of the column containing the page
                content. Defaults to "text".
            fraction_of_memory: The fraction of memory to use. Defaults to 0.1.

        Raises:
            ImportError: If ``pyspark`` (or ``psutil``, needed by
                ``get_num_rows``) is not installed.
            ValueError: If ``df`` is not a PySpark DataFrame.
        """
        try:
            from pyspark.sql import DataFrame, SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            ) from e

        self.spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )

        if not isinstance(df, DataFrame):
            raise ValueError(
                f"Expected df to be a PySpark DataFrame, got {type(df)}"
            )
        self.df = df
        self.page_content_column = page_content_column
        self.fraction_of_memory = fraction_of_memory
        # num_rows: how many rows ``load`` will return at most;
        # max_num_rows: the memory-derived cap before clamping to the row count.
        self.num_rows, self.max_num_rows = self.get_num_rows()
        # Rows are converted to plain lists so they can be paired with
        # ``column_names`` positionally in ``lazy_load``.
        self.rdd_df = self.df.rdd.map(list)
        self.column_names = self.df.columns

    def get_num_rows(self) -> Tuple[int, int]:
        """Gets the number of "feasible" rows for the DataFrame.

        The size of the first row is used as a per-row estimate and divided
        into ``fraction_of_memory`` of the currently available memory.

        Returns:
            A tuple ``(num_rows, max_num_rows)`` where ``max_num_rows`` is the
            memory-derived cap and ``num_rows`` is that cap clamped to the
            actual row count. An empty DataFrame yields ``(0, 0)``.

        Raises:
            ImportError: If ``psutil`` is not installed.
        """
        try:
            import psutil
        except ImportError as e:
            raise ImportError(
                "psutil not installed. Please install it with `pip install psutil`."
            ) from e

        first_rows = self.df.limit(1).collect()
        if not first_rows:
            # No row to measure: nothing can be loaded, and indexing [0]
            # would raise IndexError.
            return 0, 0

        # NOTE: sys.getsizeof is shallow (it does not follow references into
        # the row's values), so this is a rough lower-bound estimate.
        estimated_row_size = sys.getsizeof(first_rows[0])
        available_memory = psutil.virtual_memory().available
        max_num_rows = int(
            (available_memory / estimated_row_size) * self.fraction_of_memory
        )
        return min(max_num_rows, self.df.count()), max_num_rows

    def lazy_load(self) -> Iterator[Document]:
        """A lazy loader for document content.

        Yields:
            One ``Document`` per DataFrame row, with ``page_content_column``
            as the page content and every other column as metadata.

        Raises:
            KeyError: If ``page_content_column`` is not one of the DataFrame's
                columns (raised when the first row is processed).
        """
        for row in self.rdd_df.toLocalIterator():
            metadata = dict(zip(self.column_names, row))
            # Remove the content column from the metadata and use it as text.
            text = metadata.pop(self.page_content_column)
            yield Document(page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        """Load from the dataframe.

        Returns:
            At most ``self.num_rows`` documents; a warning is logged when the
            DataFrame holds more rows than the memory-derived cap.
        """
        # Count once: every ``count()`` call triggers a separate Spark job.
        total_rows = self.df.count()
        if total_rows > self.max_num_rows:
            logger.warning(
                "The number of DataFrame rows is %s, "
                "but we will only include the amount "
                "of rows that can reasonably fit in memory: %s.",
                total_rows,
                self.num_rows,
            )
        lazy_load_iterator = self.lazy_load()
        return list(itertools.islice(lazy_load_iterator, self.num_rows))