PySpark column-level lineage

Introduction

In this post, I will show you how to use information from the Spark plan to track data lineage at the column level. Let's say we have the following DataFrame object:

  from pyspark.sql import SparkSession, functions as F
  spark = SparkSession.builder.master("local[*]").getOrCreate()
  dat = spark.read.csv("/home/sem/github/farsante/h2o-data-rust/J1_1e8_1e5_5.csv", header=True)
  dat.printSchema()

Result:

root
 |-- id1: string (nullable = true)
 |-- id2: string (nullable = true)
 |-- id4: string (nullable = true)
 |-- id5: string (nullable = true)
 |-- v2: string (nullable = true)

Let's create some transformations on top of our dat object:

  dat_new = (
      dat.withColumn("id1_renamed", F.col("id1"))
      .withColumn("id1_and_id2", F.concat_ws("_", F.col("id1"), F.col("id2")))
      .withColumn("num1", F.lit(1))
      .withColumn("num2", F.lit(2))
      .filter(F.rand() <= F.lit(0.5))
      .select("id1_renamed", "id1_and_id2", "id1", "num1", "num2")
      .withColumn("hash_id", F.hash("id1_and_id2"))
      .join(dat.select("id1", "id4"), on=["id1"], how="left")
      .withColumn("hash_of_two_ids", F.concat_ws("_", "id4", "hash_id"))
      .groupBy("id1_renamed")
      .agg(F.count_distinct("hash_of_two_ids").alias("cnt_ids"), F.sum(F.col("num1") + F.col("num2")).alias("sum_col"))
  )

Even with such a small set of transformations, it is not at all obvious which column comes from where. Tracking transformations is exactly the goal of data lineage. There are several types of data lineage:

  • At the data source level, when we want to track all the data sources that take part in our transformations;
  • At the column level, when we want to track how each column was transformed along the way.

In this post I will focus on the second one, but the first can be achieved in a similar way. To implement it, we need to understand a little bit about how Apache Spark works, how lazy computation works, and what the directed acyclic graph (DAG) of computations is.

A short introduction to the Spark computation model and Catalyst

You can get a deeper dive by reading the original paper about Spark SQL and Catalyst, Spark SQL: Relational Data Processing in Spark. Here I will just give a top-level overview.

When you apply a transformation, like withColumn("num1", F.lit(1)), Spark only adds a step to the computation graph; it does not add an actual column to the PySpark DataFrame you are working with. So at any moment a DataFrame is not "real" data, but just a directed graph of computation steps. PySpark only provides a way to get a string representation of the plan; to work with the plan as a real graph data structure, you need the Scala/Java API of Apache Spark. When you perform an action, like df.count() or df.write, Spark takes your computation graph and executes it (a minimal sketch of this is shown right after the list below). This is a very simplified view, because in reality there are many intermediate steps:

  • Transforming the parsed logical plan into an analyzed logical plan by resolving sources and column references;
  • Optimizing the logical plan by applying optimization rules (such as pushing filter expressions down towards the data sources, or pushing column selections down to the source level);
  • Generating different candidate physical plans from the same optimized logical plan;
  • Applying cost-based selection to pick the best physical plan;
  • Performing code generation based on the selected physical plan;
  • Executing the code.

For anyone who wants to better understand how Spark works with plans and how optimizations are applied, I highly recommend the book How Query Engines Work by Andy Grove, the creator of Apache Arrow DataFusion.

Getting a string representation of the plan in PySpark

For now, we just need the analyzed logical plan, so let's write a simple Python function that returns it:

  import contextlib
  from io import StringIO

  from pyspark.sql import DataFrame


  def get_logical_plan(df: DataFrame) -> str:
      with contextlib.redirect_stdout(StringIO()) as stdout:
          df.explain(extended=True)

      plan_lines = stdout.getvalue().split("\n")

      start_line = plan_lines.index("== Analyzed Logical Plan ==") + 2
      end_line = plan_lines.index("== Optimized Logical Plan ==")

      return "\n".join(plan_lines[start_line:end_line])

It may look overly complicated, but there is no other way to get a string representation of the analyzed logical plan from PySpark. df.explain returns nothing; instead, it prints all plans (parsed logical, analyzed logical, optimized logical, physical) to standard output, which is why we need contextlib.redirect_stdout. The output is split into sections by header lines such as == Analyzed Logical Plan ==. Also, the analyzed logical plan section always starts with the output schema of the DataFrame, which is why we skip two lines after the header instead of one.
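
For reference, the whole output of df.explain(extended=True) is structured roughly like this (plan bodies omitted; the exact content depends on your DataFrame):

  == Parsed Logical Plan ==
  ...
  == Analyzed Logical Plan ==
  <output schema of the DataFrame>    <- the extra line we skip with "+ 2"
  <analyzed plan tree>
  == Optimized Logical Plan ==
  ...
  == Physical Plan ==
  ...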

Let's see what the plan looks like for our dat_new DataFrame that we created:

  print(get_logical_plan(dat_new))

Result:

Aggregate [id1_renamed#2430], [id1_renamed#2430, count(distinct hash_of_two_ids#2491) AS cnt_ids#2508L, sum((num1#2445 + num2#2454)) AS sum_col#2510L]
+- Project [id1#1321, id1_renamed#2430, id1_and_id2#2437, num1#2445, num2#2454, hash_id#2469, id4#2480, concat_ws(_, id4#2480, cast(hash_id#2469 as string)) AS hash_of_two_ids#2491]
   +- Project [id1#1321, id1_renamed#2430, id1_and_id2#2437, num1#2445, num2#2454, hash_id#2469, id4#2480]
      +- Join LeftOuter, (id1#1321 = id1#2478)
         :- Project [id1_renamed#2430, id1_and_id2#2437, id1#1321, num1#2445, num2#2454, hash(id1_and_id2#2437, 42) AS hash_id#2469]
         :  +- Project [id1_renamed#2430, id1_and_id2#2437, id1#1321, num1#2445, num2#2454]
         :     +- Filter (rand(-7677477572161899967) <= 0.5)
         :        +- Project [id1#1321, id2#1322, id4#1323, id5#1324, v2#1325, id1_renamed#2430, id1_and_id2#2437, num1#2445, 2 AS num2#2454]
         :           +- Project [id1#1321, id2#1322, id4#1323, id5#1324, v2#1325, id1_renamed#2430, id1_and_id2#2437, 1 AS num1#2445]
         :              +- Project [id1#1321, id2#1322, id4#1323, id5#1324, v2#1325, id1_renamed#2430, concat_ws(_, id1#1321, id2#1322) AS id1_and_id2#2437]
         :                 +- Project [id1#1321, id2#1322, id4#1323, id5#1324, v2#1325, id1#1321 AS id1_renamed#2430]
         :                    +- Relation [id1#1321,id2#1322,id4#1323,id5#1324,v2#1325] csv
         +- Project [id1#2478, id4#2480]
            +- Relation [id1#2478,id2#2479,id4#2480,id5#2481,v2#2482] csv

As you can see, the analyzed logical plan contains all computation steps from the last one down to the first one (Relation ... csv). An important detail is that Spark assigns a unique expression ID to each column, so the names in the plan are not plain column names but something like name#unique_id. For example, after the self-join two different id1 columns appear in the plan (id1#1321 and id1#2478), and the unique IDs keep them apart. This will help us a lot when we create our column lineage parser, because it simplifies everything: we do not need to worry about collisions or renames, since Spark has already solved those problems for us!

Parsing the plan to get column lineage

As you can see, only a limited set of operations can appear in the plan:

  • Relation: mapping of columns to files or tables;
  • Project: any column operation, such as withColumn, withColumnRenamed, select, etc.;
  • Filter: any filter operation;
  • Join: various types of join operations;
  • Aggregate: aggregate operations.

There are also additional cases like Union, but the union operation makes things much more complex, so let's leave it out of scope: if a plan contains a Union, a column can come from either side of the union-like operation, which makes parsing considerably harder.

Defining an output data structure and the user API

First, we need to define what our column lineage will look like and which data structure will represent it. By design, data lineage is a directed acyclic graph (in our case, effectively a tree). One of the simplest ways to represent a graph-like structure is a plain list of edges (an edge list). Nodes of our graph will carry not only IDs but also some additional information, like a description of the computation step, so let's store those attributes in a dict-like structure. The API should be very simple: just a function that takes a DataFrame object and a column name. For convenience, we will also store the list of all nodes in the graph. Let's define the structure and the function signature:

  from dataclasses import dataclass

  @dataclass
  class ColumnLineageGraph:
      """Structure to represent columnar data lineage."""

      nodes: list[int]  # list of hash values that represent nodes
      edges: list[list[int]]  # list of edges in the form of list of pairs
      node_attrs: dict[int, str]  # labels of nodes (expressions)

  def get_column_lineage(df: DataFrame, column: str) -> ColumnLineageGraph:
      raise NotImplementedError()
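
To make the shape of this structure concrete, here is a purely illustrative, hand-built instance (the node hashes and labels are made up) for a hypothetical column col_b computed as col_a + 1 from a CSV source:

  example = ColumnLineageGraph(
      nodes=[101, 202],
      edges=[[202, 101]],  # an edge points from the upstream node to the node derived from it
      node_attrs={
          101: "(col_a + 1) AS col_b",
          202: "Read from csv: col_a",
      },
  )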

Creating a recursive parsing function

We will be using a lot of regular expressions, so let's import the re module first:

  import re

Transforming from graph nodes to column names

It doesn't really matter that our logical plan is now just a list of strings: by design it is a tree structure, and the best way to traverse a tree is, of course, recursion. Before writing the recursive traversal itself, let's define a couple of small helpers. The first one turns a plan node back into a column name:

  def _node2column(node: str) -> str:
      """Inner function. Transform the node from plan to column name.
      Like: col_11#1234L -> col_11.
      """
      match_ = re.match(r"([\w\d]+)#[\w\d]+", node)
      if match_:
          return match_.groups()[0]
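
A quick sanity check on a node taken from our plan above:

  _node2column("cnt_ids#2508L")  # should return 'cnt_ids'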

We also need a way to find, for an aliased column, the expression it was computed from and all the column nodes referenced in that expression. Let's add another small function for this (it relies on _extract_alias_expressions, which we will write next):

  def _get_aliases(col: str, line: str) -> tuple[list[str], str]:
      """Inner function. Returns all the aliases from the expr and expr itself."""
      alias_exp = _extract_alias_expressions(col, line)
      # Regexp to extract columns: each column has a pattern like col_name#1234
      return (re.findall(r"[\w\d]+#[\w\d]+", alias_exp), alias_exp)

Parsing ALIAS expressions

One of the most complicated cases in a Spark plan is an alias. You may face the following variants:

  • Literal expressions, like 1 AS col#1234;
  • Just an alias, like col1#1234 AS col2#1235;
  • An alias to the expression, like (col1#1234 + col2#1235) AS col3#1236.

And the last one can contain an unlimited number of nested expressions. It is almost impossible to parse such a case with regular expressions alone; we need to balance parentheses, just like in an easy LeetCode task. I will use a counter-based approach: we keep a counter of unbalanced closing parentheses, and we have found the start of the expression once the parentheses are balanced.

  def _extract_alias_expressions(col: str, line: str) -> str:
      """Inner function. Extract expression before ... AS col from the line."""
      num_close_parentheses = 0 # our counter
      idx = line.index(f" AS {col}") # the end of the alias expression we need to parse
      alias_expr = [] # buffer to store what we are parsing

      if line[idx - 1] != ")":
          """It is possible that there is no expression.
          It is the case when we just make a rename of the column. In the plan
          it will look like `col#123 AS col#321`;
          """
          for j in range(idx - 1, 0, -1):
              alias_expr.append(line[j])
              if line[j - 1] == "[":
                  break
              if line[j - 1] == " ":
                  break
          return "".join(alias_expr)[::-1]

      """In all other cases there will be `(` at the end of the expr before AS.
      Our goal is to go symbol by symbol back until we balance all the parentheses.
      """
      for i in range(idx - 1, 0, -1):
          alias_expr.append(line[i])
          if line[i] == ")":
              # Add parenthesis
              num_close_parentheses += 1
          if line[i] == "(":
              if num_close_parentheses == 1:
                  # Parentheses are balanced
                  break
              # Remove parenthesis
              num_close_parentheses -= 1

      """After balancing parentheses we need to parse leading expression.
      It is always here because we checked single alias case separately."""
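      # Note: starting this loop at `i` (not `i - 1`) re-appends the opening
      # parenthesis, which is why the extracted expressions show a doubled "((";
      # this is purely cosmetic and does not affect the column nodes we extract.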
      for j in range(i, 0, -1):
          alias_expr.append(line[j])
          if line[j - 1] == "[":
              break
          if line[j - 1] == " ":
              break

      return "".join(alias_expr[::-1])

It may look like magic, so let's check how it works on an example taken from our real plan:

  _extract_alias_expressions(
      "id1_and_id2#2437",
      "Project [id1#1321, id2#1322, id4#1323, id5#1324, v2#1325, id1_renamed#2430, concat_ws(_, id1#1321, id2#1322) AS id1_and_id2#2437]"
  )

And the result is:

  'concat_ws((_, id1#1321, id2#1322)'

Looks like it works: we recover the expression and both source column nodes. (The doubled "((" is the cosmetic quirk noted in the code above; it does not affect the column nodes we extract.) Finally, some knowledge from LeetCode tasks put into practice!
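
And _get_aliases, built on top of it, should give us both the referenced column nodes and the expression itself:

  _get_aliases(
      "id1_and_id2#2437",
      "Project [id1#1321, id2#1322, id4#1323, id5#1324, v2#1325, id1_renamed#2430, concat_ws(_, id1#1321, id2#1322) AS id1_and_id2#2437]"
  )
  # expected: (['id1#1321', 'id2#1322'], 'concat_ws((_, id1#1321, id2#1322)')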

Parsing aggregation-like expressions

In most cases we do not need any columns from a plan line other than the one we are working with. The only exception is aggregation: it is useful to keep information about the aggregation keys in the final node attributes. Let's add a simple function to do this:

  def _add_aggr_or_not(expr: str, line: str) -> str:
      """If the expr is aggregation we should add agg keys to the beginning."""
      # We are checking for aggregation pattern
      match_ = re.match(r"^[\s\+\-:]*Aggregate\s\[([\w\d#,\s]+)\].*$", line)
      if match_:
          agg_expr = match_.groups()[0]
          return (
              "GroupBy: " + re.sub(r"([\w\d]+)#([\w\d]+)", r"\1", agg_expr) + f"\n{expr}"
          )

      # If not, just return the original expr
      return expr
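
A quick check against the Aggregate line from our plan, using an example expression string of the kind the parser produces as the first argument:

  _add_aggr_or_not(
      "count(distinct hash_of_two_ids) AS cnt_ids",
      "Aggregate [id1_renamed#2430], [id1_renamed#2430, count(distinct hash_of_two_ids#2491) AS cnt_ids#2508L, sum((num1#2445 + num2#2454)) AS sum_col#2510L]",
  )
  # expected: 'GroupBy: id1_renamed\ncount(distinct hash_of_two_ids) AS cnt_ids'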

Building the final recursive parser

Now we have everything we need. So let's go through the logical plan line by line, adding nodes and attributes to our graph structure:

  def _get_graph(lines: list[str], node: str) -> tuple[list[int], list[list[int]], dict[int, str]]:
      nodes = []
      edges = []
      node_attrs = {}

      for i, l in enumerate(lines):  # noqa: E741
          """Iteration over lines of logical plan."""

          # We should use hash of line + node as a key in the graph.
          # It is not enough to use only hash of line because the same line
          # may be related to multiple nodes!
          # A good example is reading the CSV that is represented by one line!
          h = hash(l + node)

          # If the current node is not root we need to store hash of previous node.
          prev_h = None if not nodes else nodes[-1]

          if node not in l:
              continue
          if f"AS {node}" in l:
              """It is a hard case, when a node is an alias to some expression."""
              aliases, expr = _get_aliases(node, l)
              # For visualization we need to transform from nodes to columns
              expr = re.sub(r"([\w\d]+)#([\w\d]+)", r"\1", expr)

              # Append a new node
              nodes.append(h)
              # Append expr as an attribute of the node
              node_attrs[h] = _add_aggr_or_not(f"{expr} AS {_node2column(node)}", l)

              if len(aliases) == 1:
                  # It is the case of a simple alias,
                  # like col1#123 AS col2#321.
                  # In this case we just replace the old node with the new one.
                  if prev_h:
                      edges.append([h, prev_h])
                  node = aliases[0]
              else:
                  # It is the case of a complex expression.
                  # Here we recursively go through all the nodes from the expr.
                  if prev_h:
                      edges.append([h, prev_h])
                  for aa in aliases:
                      # Get graph from sub-column
                      sub_nodes, sub_edges, sub_attrs = _get_graph(lines[i:], aa)

                      # Add everything to the current graph
                      nodes.extend(sub_nodes)
                      edges.extend(sub_edges)
                      node_attrs = {**node_attrs, **sub_attrs}

                      # Add connection between top subnode and node
                      edges.append([sub_nodes[0], h])
                  return (nodes, edges, node_attrs)
          else:
              # The node appears in this line, but not as an alias target.
              # In the future there may be more cases; that is why we use a nested if instead of elif.
              if "Relation" in l:
                  nodes.append(h)
                  if prev_h:
                      edges.append([h, prev_h])

                  # It is a pattern, related to data-sources (like CSV)
                  match_ = re.match(r"[\s\+\-:]*Relation\s\[.*\]\s(\w+)", l)
                  if match_:
                      s_ = "Read from {}: {}"
                      # Add data-source as a node
                      node_attrs[h] = s_.format(match_.groups()[0], _node2column(node))
                  else:
                      # Fallback to avoid empty graphs and related runtime exceptions
                      node_attrs[h] = f"Relation to: {_node2column(node)}"

              elif "Join" in l:
                  nodes.append(h)
                  if prev_h:
                      edges.append([h, prev_h])
                  match_ = re.match(r"[\s\+\-:]*Join\s(\w+),\s\((.*)\)", l)
                  if match_:
                      join_type = match_.groups()[0]
                      join_expr = match_.groups()[1]
                      join_expr_clr = re.sub(r"([\w\d]+)#([\w\d]+)", r"\1", join_expr)
                      node_attrs[h] = f"{join_type}: {join_expr_clr}"
              else:
                  continue

      if not nodes:
          # Just the case of empty return. We need to avoid it.
          # I'm not sure that line is reachable.
          nodes.append(h)
          node_attrs[h] = f"Select: {_node2column(node)}"

      return (nodes, edges, node_attrs)
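
Before wiring everything together, we can already try the parser on the plan we captured earlier (note that the expression IDs differ between sessions, so take the ID of your target column from your own plan):

  lines = get_logical_plan(dat_new).split("\n")
  # "cnt_ids#2508L" is the node from the plan shown above; your ID will differ
  nodes, edges, attrs = _get_graph(lines, "cnt_ids#2508L")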

All together

Now we are ready to put all the pieces together into a single function:

  def get_column_lineage(df: DataFrame, column: str) -> ColumnLineageGraph:
      """Get data lineage on the level of the given column.

      Currently Union operation is not supported! API is unstable, no guarantee
      that custom spark operations or connectors won't break it!

      :param df: DataFrame
      :param column: column
      :returns: Struct with nodes, edges and attributes
      """
      lines = get_logical_plan(df).split("\n")

      # The top plan line contains the plan node for our column; we need that node.
      # The general pattern of a node is column#12345L, i.e. [\w\d]+#[\w\d]+
      match_ = re.match(r".*(" + column + r"#[\w\d]+).*", lines[0])
      if match_:
          node = match_.groups()[0]
      else:
          err = f"There is no column {column} in the final schema of DF!"
          raise KeyError(err)

      nodes, edges, attrs = _get_graph(lines, node)

      return ColumnLineageGraph(nodes, edges, attrs)

Testing and drawing our implementation

Let's see how our function works:

  get_column_lineage(dat_new, "cnt_ids")

This will produce the following:

  ColumnLineageGraph(nodes=[-3047688324833821294, 8934572903754805890, -22248459158511064, -3092611391038289840, 1490298382268190732, -6431655222193019101, -1002279244933706460], edges=[[8934572903754805890, -3047688324833821294], [-22248459158511064, 8934572903754805890], [1490298382268190732, -3092611391038289840], [-6431655222193019101, 1490298382268190732], [-1002279244933706460, 1490298382268190732], [-3092611391038289840, 8934572903754805890]], node_attrs={-3047688324833821294: 'GroupBy: id1_renamed\ncount((distinct hash_of_two_ids) AS cnt_ids', 8934572903754805890: 'concat_ws((_, id4, cast(hash_id as string)) AS hash_of_two_ids', -22248459158511064: 'Read from csv: id4', -3092611391038289840: 'hash((id1_and_id2, 42) AS hash_id', 1490298382268190732: 'concat_ws((_, id1, id2) AS id1_and_id2', -6431655222193019101: 'Read from csv: id1', -1002279244933706460: 'Read from csv: id2'})

Looks like it works, at least in our simple case.

Drawing the graph

To draw the graph as a tree, let's use the Python library NetworkX with Graphviz as the layout engine. You need the following packages installed:

  • networkx
  • pygraphviz
  • matplotlib

With those in place, here is the plotting function:

  def plot_column_lineage_graph(
      df: DataFrame,
      column: str,
  ) -> "matplotlib.pyplot.Figure":
      """Plot the column lineage graph as matplotlib figure.

      :param df: DataFrame
      :param column: column
      :returns: matplotlib.pyplot.Figure
      """
      try:
          import networkx as nx
          from networkx.drawing.nx_agraph import graphviz_layout
      except ModuleNotFoundError as e:
          err = "NetworkX is not installed. Try `pip install networkx`. "
          err += (
              "You may use `get_column_lineage` instead, that doesn't require NetworkX."
          )
          raise ModuleNotFoundError(err) from e

      try:
          import matplotlib.pyplot as plt
      except ModuleNotFoundError as e:
          err = "You need matplotlib installed to draw the Graph"
          raise ModuleNotFoundError(err) from e

      import importlib.util

      if not importlib.util.find_spec("pygraphviz"):
          err = "You need to have pygraphviz installed to draw the Graph"
          raise ModuleNotFoundError(err)

      lineage = get_column_lineage(df, column)
      g = nx.DiGraph()

      g.add_nodes_from(lineage.nodes)
      g.add_edges_from(lineage.edges)

      pos = graphviz_layout(g, prog="twopi")
      pos_attrs = {}
      for node, coords in pos.items():
          pos_attrs[node] = (coords[0], coords[1] + 10)
      nx.draw(g, pos=pos)
      nx.draw_networkx_labels(g, labels=lineage.node_attrs, pos=pos_attrs, clip_on=False)

      return plt.gcf()

If we run it, we get the following:

  import matplotlib.pyplot as plt

  col = "cnt_ids"
  f = plot_column_lineage_graph(dat_new, col)
  plt.show()

[Figure: column lineage graph for the cnt_ids column]

Looks exactly like what we need!

Afterword

This functionality is mostly for educational purposes, to better understand how the Spark plan is organized. Another possible use case is when you need some simple, self-contained Python code for this task. For real production data lineage on top of Spark, I recommend the Spline project!