How to Calculate Lag by Group in PySpark


You can use the following syntax to calculate lagged values by group in a PySpark DataFrame:

from pyspark.sql.window import Window
from pyspark.sql.functions import lag

#specify grouping and ordering variables
w = Window.partitionBy('store').orderBy('day')

#calculate lagged sales by group
df_new = df.withColumn('lagged_sales', lag(df.sales, 1).over(w))

This particular example creates a new column named lagged_sales that contains the previous row's value from the sales column, grouped by the values in the store column and ordered by the values in the day column.
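The second argument to lag controls how many rows back to look. As a minimal sketch, you could lag sales by two days instead of one (the column name lagged_sales_2 and the offset of 2 are just illustrative):

#calculate sales lagged by two rows within each store
df_new2 = df.withColumn('lagged_sales_2', lag(df.sales, 2).over(w))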

The following example shows how to use this syntax in practice.

Example: How to Calculate Lagged Values by Group in PySpark

Suppose we have the following PySpark DataFrame that contains information about sales made during consecutive days at two different stores:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [['A', 1, 18], 
        ['A', 2, 33], 
        ['A', 3, 12], 
        ['A', 4, 15], 
        ['A', 5, 19],
        ['B', 1, 24],
        ['B', 2, 28],
        ['B', 3, 40],
        ['B', 4, 24],
        ['B', 5, 13]]
  
#define column names
columns = ['store', 'day', 'sales'] 
  
#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view dataframe
df.show()

We can use the following syntax to calculate the lagged values in the sales column, grouped by the values in the store column:

from pyspark.sql.window import Window
from pyspark.sql.functions import lag

#specify grouping and ordering variables
w = Window.partitionBy('store').orderBy('day')

#calculate lagged sales by group
df_new = df.withColumn('lagged_sales', lag(df.sales, 1).over(w))

#view new DataFrame
df_new.show()

+-----+---+-----+------------+
|store|day|sales|lagged_sales|
+-----+---+-----+------------+
|    A|  1|   18|        null|
|    A|  2|   33|          18|
|    A|  3|   12|          33|
|    A|  4|   15|          12|
|    A|  5|   19|          15|
|    B|  1|   24|        null|
|    B|  2|   28|          24|
|    B|  3|   40|          28|
|    B|  4|   24|          40|
|    B|  5|   13|          24|
+-----+---+-----+------------+

The new column named lagged_sales shows the lagged sales values for each store.

For example:

  • The first value in the lagged_sales column is null since there is no prior value in the sales column.
  • The second value in the lagged_sales column is 18 since this is the prior value in the sales column.

And so on.
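A common follow-up is to compare each day's sales to the prior day's sales. As a minimal sketch (the sales_change column name is just illustrative), you can subtract the lagged column from the original sales column:

#calculate day-over-day change in sales within each store
from pyspark.sql.functions import col

df_change = df_new.withColumn('sales_change', col('sales') - col('lagged_sales'))
df_change.show()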

If you’d like, you can use the fillna function to replace the null values in the lagged_sales column with zero:

#replace null values with 0 in lagged_sales column
df_new.fillna(0, 'lagged_sales').show()

+-----+---+-----+------------+
|store|day|sales|lagged_sales|
+-----+---+-----+------------+
|    A|  1|   18|           0|
|    A|  2|   33|          18|
|    A|  3|   12|          33|
|    A|  4|   15|          12|
|    A|  5|   19|          15|
|    B|  1|   24|           0|
|    B|  2|   28|          24|
|    B|  3|   40|          28|
|    B|  4|   24|          40|
|    B|  5|   13|          24|
+-----+---+-----+------------+

Each of the null values in the lagged_sales column has now been replaced with zero.
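Alternatively, the lag function accepts a default value as an optional third argument, so you can avoid the null values in the first place. A minimal sketch using a default of 0:

#use a default of 0 when no prior sales value exists within a store
df_new = df.withColumn('lagged_sales', lag(df.sales, 1, 0).over(w))
df_new.show()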

Note: You can find the complete documentation for the lag function in the official PySpark documentation.
