How can I calculate a rolling mean in PySpark?

Calculating a rolling mean in PySpark means finding the average of a set of values over a fixed number of the most recent time periods. This is often used in time series analysis to smooth out data and identify trends. To calculate a rolling mean in PySpark, you can use a Window specification together with the avg function to restrict the calculation to a sliding frame of rows and compute the average within each frame. This approach scales efficiently to large datasets.

Calculate a Rolling Mean in PySpark


You can use the following syntax to calculate a rolling mean in a PySpark DataFrame:

from pyspark.sql import Window
from pyspark.sql import functions as F

#define window that spans the current row and the 3 preceding rows, ordered by day
w = Window.orderBy('day').rowsBetween(-3, 0)

#create new DataFrame that contains 4-day rolling mean column
df_new = df.withColumn('rolling_mean', F.avg('sales').over(w))

This particular example creates a new column that contains the 4-day rolling average of values in the sales column of the DataFrame. The rowsBetween(-3, 0) clause tells Spark to include the 3 preceding rows and the current row in each average, for a total of 4 days.
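Note that this window has no partitionBy clause, so the rolling mean is calculated over the entire DataFrame ordered by day. If your data contains several groups (for example, sales from multiple stores), one common variation is to add partitionBy so the rolling mean restarts within each group. Here is a minimal sketch that assumes a hypothetical store column:

from pyspark.sql import Window
from pyspark.sql import functions as F

#hypothetical variation: 4-day rolling mean calculated separately for each store
#assumes the DataFrame also contains a 'store' column
w_store = Window.partitionBy('store').orderBy('day').rowsBetween(-3, 0)

df_by_store = df.withColumn('rolling_mean', F.avg('sales').over(w_store))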

The following example shows how to use this syntax in practice.

Example: How to Calculate a Rolling Mean in PySpark

Suppose we have the following PySpark DataFrame that contains information about the sales made during 10 consecutive days at some grocery store:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [[1, 11], 
        [2, 8], 
        [3, 4], 
        [4, 5], 
        [5, 5], 
        [6, 8],
        [7, 7],
        [8, 7],
        [9, 6],
        [10, 4]] 
  
#define column names
columns = ['day', 'sales']

#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view dataframe
df.show()

+---+-----+
|day|sales|
+---+-----+
|  1|   11|
|  2|    8|
|  3|    4|
|  4|    5|
|  5|    5|
|  6|    8|
|  7|    7|
|  8|    7|
|  9|    6|
| 10|    4|
+---+-----+

We can use the following syntax to calculate the 4-day rolling mean of values in the sales column:

from pyspark.sql import Window
from pyspark.sql import functions as F

#define window that spans the current row and the 3 preceding rows, ordered by day
w = Window.orderBy('day').rowsBetween(-3, 0)

#create new DataFrame that contains 4-day rolling mean column
df_new = df.withColumn('rolling_mean', F.avg('sales').over(w))

#view new DataFrame
df_new.show()

+---+-----+-----------------+
|day|sales|     rolling_mean|
+---+-----+-----------------+
|  1|   11|             11.0|
|  2|    8|              9.5|
|  3|    4|7.666666666666667|
|  4|    5|              7.0|
|  5|    5|              5.5|
|  6|    8|              5.5|
|  7|    7|             6.25|
|  8|    7|             6.75|
|  9|    6|              7.0|
| 10|    4|              6.0|
+---+-----+-----------------+

The resulting DataFrame contains a new column called rolling_mean that shows the rolling mean of the values in the sales column for the most recent 4 days.
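The rolling_mean column can contain values with long decimal expansions, such as the one on day 3. If you prefer cleaner output, one option is to round the column with the round function from pyspark.sql.functions, as sketched below:

from pyspark.sql import functions as F

#round the rolling mean to 2 decimal places for easier reading
df_rounded = df_new.withColumn('rolling_mean', F.round('rolling_mean', 2))

#view the rounded DataFrame
df_rounded.show()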

For example, the rolling mean of values in the sales column on day 4 is calculated as: 

Rolling Mean = (11 + 8 + 4 + 5) / 4 = 7

And the rolling mean of values in the sales column on day 5 is calculated as: 

Rolling Mean = (8 + 4 + 5 + 5) / 4 = 5.5

And so on.
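As a quick sanity check, you can reproduce these two values with plain Python arithmetic:

#verify the day 4 and day 5 rolling means by hand
day_4 = (11 + 8 + 4 + 5) / 4   #7.0
day_5 = (8 + 4 + 5 + 5) / 4    #5.5

print(day_4, day_5)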

Note that you can calculate a rolling average over a different number of previous periods by simply changing the first argument of the rowsBetween method.

For example, you could calculate a 5-day rolling average by using rowsBetween(-4, 0) instead.
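For instance, a 5-day version of the window used above (shown here with the hypothetical names w5 and rolling_mean_5d) would look like this:

from pyspark.sql import Window
from pyspark.sql import functions as F

#define window that spans the current row and the 4 preceding rows (5 days total)
w5 = Window.orderBy('day').rowsBetween(-4, 0)

#create new DataFrame that contains 5-day rolling mean column
df_5day = df.withColumn('rolling_mean_5d', F.avg('sales').over(w5))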
