How to estimate the size of a PySpark DataFrame?

Sometimes an important question comes up: how much memory does our DataFrame use? There is no easy answer if you are working with PySpark. You can collect a data sample and run a local memory profiler, or you can estimate the size of the data at the source (for example, in a Parquet file). But we will go another way and analyze Spark's logical plan from PySpark. When working with the Scala Spark API we can access the resolved or unresolved logical plans and the physical plan via a dedicated API; from the PySpark API only a string representation is available, so that is what we will work with.
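
One of the alternatives mentioned above, collecting a sample and profiling it locally, could look roughly like the sketch below. The 1% fraction and the pandas-based measurement are my own choices (not from this article), and scaling a pandas estimate back up only gives a very rough idea of the Spark-side size:

  # Rough sketch: profile a small sample locally with pandas and scale it back up.
  # The fraction is arbitrary, and pandas' memory layout differs from Spark's,
  # so treat the result as a ballpark figure only.
  fraction = 0.01
  sample_pdf = df.sample(fraction=fraction, seed=42).toPandas()
  approx_bytes = sample_pdf.memory_usage(deep=True).sum() / fraction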

Getting the logical plan

There are two ways to get the logical plan: the first is the SQL command EXPLAIN and the second is df.explain. We will use the second one to skip the step of creating a temporary view from the DataFrame. df.explain does not provide an API that returns a string representation of the plan; the plan is printed to standard output instead. So we need to redirect standard output to capture it:

  import contextlib
  import io

  with contextlib.redirect_stdout(io.StringIO()) as stdout:
      df.explain(mode="cost")

  logical_plan = stdout.getvalue().split("\n")

Here we used the argument mode="cost". Based on the documentation, this mode means:

cost: Print a logical plan and statistics if they are available.

So we will get all the available statistics, including the estimated size of the DataFrame that we are looking for.
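
For completeness, the first option mentioned earlier, the SQL command EXPLAIN, would give the same output. A minimal sketch (the temporary view name df_view is my own):

  # Sketch: the same statistics via the SQL EXPLAIN COST command
  df.createOrReplaceTempView("df_view")
  spark.sql("EXPLAIN COST SELECT * FROM df_view").show(truncate=False)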

Let's see what the plan looks like. But first we need some DataFrames of different sizes for our experiments. I was working on a project to rewrite the h2o db-benchmark data generation from R to Rust (it is now part of the farsante repository) and had some CSV benchmark datasets of different sizes, from a few KB to a few GB. We can use them for our experiments:

  from pyspark.sql import SparkSession

  spark = SparkSession.builder.master("local[*]").getOrCreate()
  medium_data = spark.read.csv("/home/sem/github/farsante/h2o-data-rust/G1_1e8_1e8_10_5.csv")
  small_data = spark.read.csv("/home/sem/github/farsante/h2o-data-rust/J1_1e8_1e5_5.csv")
  tiny_data = spark.read.csv("/home/sem/github/farsante/h2o-data-rust/J1_1e8_1e2_5.csv")

And the output logical plan looks like this:

  medium_data.explain(mode="cost")
  == Optimized Logical Plan ==
  Relation [_c0#17,_c1#18,_c.....,_c7#24,_c8#25] csv, Statistics(sizeInBytes=4.5 GiB)

  == Physical Plan ==
  FileScan csv [_c0#17,_c1#18,_c......,_c7#24,_c8#25] Batched: false, DataFilters: [], Format: CSV, Loc...

As you can see, the information we are looking for is in the first (top) line, and it will always be there: the logical plan is printed "in reverse", from the last operation to the first, line by line, so the top line always represents the current state of our DataFrame.

Parsing the top line of the logical plan

As you can see, we are looking for the number in this part of the line: Statistics(sizeInBytes=4.5 GiB). Let's use Python's built-in regular expressions to extract it:

  import re
  pattern = r"^.*sizeInBytes=([0-9]+\.[0-9]+)\s(B|KiB|MiB|GiB|TiB|EiB).*$"

Let's look at the pattern. It says:

  1. ^.*: any number of characters at the beginning of the line.
  2. sizeInBytes=([0-9]+\.[0-9]+): a number of the form 1234.1234 that comes right after the word sizeInBytes and an equals sign. We capture this number as the first regex group.
  3. \s(B|KiB|MiB|GiB|TiB|EiB): the second group, the units, which follow the first group after exactly one whitespace character.
  4. .*$: any number of characters at the end of the line.
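
To make the groups concrete, here is the pattern applied to a made-up Statistics line (the line itself is synthetic, only for illustration):

  import re

  pattern = r"^.*sizeInBytes=([0-9]+\.[0-9]+)\s(B|KiB|MiB|GiB|TiB|EiB).*$"

  # A synthetic plan line, only to show what the two groups capture
  sample = "Relation [_c0#17,_c1#18] csv, Statistics(sizeInBytes=123.4 MiB)"
  print(re.match(pattern, sample).groups())  # ('123.4', 'MiB')

Now let's apply the same pattern to the top line of the real plan:
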
  with contextlib.redirect_stdout(io.StringIO()) as stdout:
      medium_data.explain(mode="cost")
  plan = stdout.getvalue()
  top_line = plan.split("\n")[1]
  re.match(pattern, top_line).groups()

Result:

  ('4.5', 'GiB')

Looks like it works!
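
If you would rather not depend on the Statistics line always being at index 1 of the captured output, a slightly more defensive variant (just a sketch) locates it relative to the plan header:

  # Sketch: find the line right after the "== Optimized Logical Plan ==" header
  # instead of hard-coding its position in the output.
  lines = plan.split("\n")
  top_line = lines[lines.index("== Optimized Logical Plan ==") + 1]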

Corner case: what happens if Spark doesn't know the size?

Before we wrap our code into a Python function, let's check what happens if Spark doesn't know the size of the data. This is the common case for DataFrames that are created from in-memory data rather than read from disk.

  from pyspark.sql import functions as F

  data = [(i, f"id{i}", f"id2{i}", f"id3{i}") for i in range(1_100_000)]
  sdf = spark.createDataFrame(
      data,
      schema="struct<c1:int,c2:string,c3:string,c4:string>"
  ).withColumn("new_col", F.col("c1") * 4)

  with contextlib.redirect_stdout(io.StringIO()) as stdout:
      sdf.explain(mode="cost")

  plan = stdout.getvalue()
  top_line = plan.split("\n")[1]

  re.match(pattern, top_line).groups()

Result:

  ('8.4', 'EiB')

This is not what we expected, is it? An EiB is something like \(\simeq 10^6\) TiB… The answer is simple: if Spark cannot estimate the size, it simply returns the maximum available value (Scala's Long.MaxValue). You might call this a bug, but after reading this discussion I understood that there is no easy way to work around it on the Spark side. So let's just catch this case on the Python side. Unfortunately, our final code with the workaround won't work if your data really is measured in EiB, but I can't imagine such an amount in a single Spark job.
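
To put these numbers in perspective, here is the arithmetic behind that claim (plain Python, nothing Spark-specific):

  # Scala Long.MaxValue interpreted as a byte count is about 8 EiB,
  # and one EiB is about a million TiB.
  long_max = 2**63 - 1
  print(long_max / 2**60)   # ~8.0 (EiB)
  print(2**60 / 2**40)      # 1048576.0, i.e. ~10^6 TiB per EiB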

Finalized code

  import contextlib
  import io
  import re

  from pyspark.sql import DataFrame

  def _bytes2mb(bb: float) -> float:
      return bb / 1024 / 1024


  def estimate_size_of_df(df: DataFrame, size_in_mb: bool = False) -> float:
      """Estimate the size in Bytes of the given DataFrame.
      If the size cannot be estimated return -1.0. It is possible if
      we failed to parse plan or, most probably, it is the case when statistics
      is unavailable. There is a problem that currently in the case of missing
      statistics spark return 8 (or 12) EiB. If your data size is really measured in EiB
      this function cannot help you. See https://github.com/apache/spark/pull/31817
      for details. Size is returned in Bytes!

      This function works only in PySpark 3.0.0 or higher!

      :param df: DataFrame
      :param size_in_mb: Convert output to Mb instead of B
      :returns: size in bytes (or Mb if size_in_mb)
      """
      with contextlib.redirect_stdout(io.StringIO()) as stdout:
          # mode argument was added in 3.0.0
          df.explain(mode="cost")

      # Get top line of Optimized Logical Plan
      # The output of df.explain(mode="cost") starts from the following line:
      # == Optimized Logical Plan ==
      # The next line after this should contain something like:
      # Statistics(sizeInBytes=3.0 MiB) (units may differ)
      top_line = stdout.getvalue().split("\n")[1]

      # We need a pattern to parse the size and its units
      pattern = r"^.*sizeInBytes=([0-9]+\.[0-9]+)\s(B|KiB|MiB|GiB|TiB|EiB).*$"

      _match = re.search(pattern, top_line)

      if _match:
          size = float(_match.groups()[0])
          units = _match.groups()[1]
      else:
          return -1.0

      if units == "KiB":
          size *= 1024

      if units == "MiB":
          size *= 1024 * 1024

      if units == "GiB":
          size *= 1024 * 1024 * 1024

      if units == "TiB":
          size *= 1024 * 1024 * 1024 * 1024

      if units == "EiB":
          # Most probably it is the case when Statistics is unavailable
          # In this case spark just returns max possible value
          # See https://github.com/apache/spark/pull/31817 for details
          size = -1

      if size < 0:
          return size

      if size_in_mb:
          return _bytes2mb(size)  # size in MB

      return size  # size in bytes
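
As a side note on the design, the chain of if statements above could also be written as a lookup table; a sketch with the same behavior:

  # Sketch: an equivalent lookup-table version of the unit handling above.
  # "EiB" is left out on purpose: it signals missing statistics and maps to -1.0.
  _MULTIPLIERS = {"B": 1, "KiB": 1024, "MiB": 1024**2, "GiB": 1024**3, "TiB": 1024**4}

  def _to_bytes(size: float, units: str) -> float:
      return size * _MULTIPLIERS[units] if units in _MULTIPLIERS else -1.0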

Testing

  print(estimate_size_of_df(medium_data, size_in_mb=False))
  print(estimate_size_of_df(medium_data, size_in_mb=True))
  print(estimate_size_of_df(small_data, size_in_mb=True))
  print(estimate_size_of_df(tiny_data, size_in_mb=False))

Result:

  4831838208.0
  4608.0
  3.0
  1691.0
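
And as a final check, the in-memory DataFrame from the corner-case section should now be reported as "unknown"; a quick sketch reusing sdf from above (the expected value follows from the EiB branch of the function):

  print(estimate_size_of_df(sdf))  # expected: -1.0, statistics unavailable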