How to Use a Case Statement in PySpark (With Example)


A case statement checks a series of conditions in order and returns the value associated with the first condition that is met.
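The first-match rule can be illustrated without Spark. The following plain-Python sketch is only a model of the idea, not how PySpark evaluates when() internally. The function name case and the branches list are hypothetical. It checks each (predicate, result) pair in order, stops at the first predicate that is true, and falls back to a default value if none match.

```python
def case(value, branches, default):
    """Return the result of the first (predicate, result) pair whose predicate is true."""
    for predicate, result in branches:
        if predicate(value):
            return result  # stop at the first match; later branches are never checked
    return default  # no predicate matched


# Hypothetical branches mirroring the points thresholds used in this article
branches = [
    (lambda p: p < 9, 'Bad'),
    (lambda p: p < 12, 'OK'),
    (lambda p: p < 15, 'Good'),
]

print([case(p, branches, 'Great') for p in [6, 9, 12, 15]])
# ['Bad', 'OK', 'Good', 'Great']
```

Note that 9 is labeled OK even though it also satisfies p < 15. Once p < 12 matches, the remaining branches are skipped.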

The easiest way to implement a case statement in a PySpark DataFrame is by using the following syntax:

from pyspark.sql.functions import when

df.withColumn('class', when(df.points < 9, 'Bad')
                       .when(df.points < 12, 'OK')
                       .when(df.points < 15, 'Good')
                       .otherwise('Great')).show()

This particular example returns a new DataFrame with an additional column called class that takes on the following values:

  • Bad if the value in the points column is less than 9
  • OK if the value in the points column is at least 9 but less than 12
  • Good if the value in the points column is at least 12 but less than 15
  • Great if none of the previous conditions are true

The following example shows how to use this function in practice.

Example: How to Use a Case Statement in PySpark

Suppose we have the following PySpark DataFrame:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [['A', 6],
        ['B', 8], 
        ['C', 9], 
        ['D', 9], 
        ['E', 12], 
        ['F', 14],
        ['G', 15],
        ['H', 17],
        ['I', 19],
        ['J', 22]] 
  
#define column names
columns = ['player', 'points'] 
  
#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view dataframe
df.show()

+------+------+
|player|points|
+------+------+
|     A|     6|
|     B|     8|
|     C|     9|
|     D|     9|
|     E|    12|
|     F|    14|
|     G|    15|
|     H|    17|
|     I|    19|
|     J|    22|
+------+------+

We can use the following syntax to write a case statement that creates a new column called class whose values are determined by the values in the points column:

from pyspark.sql.functions import when

df.withColumn('class', when(df.points < 9, 'Bad')
                       .when(df.points < 12, 'OK')
                       .when(df.points < 15, 'Good')
                       .otherwise('Great')).show()

+------+------+-----+
|player|points|class|
+------+------+-----+
|     A|     6|  Bad|
|     B|     8|  Bad|
|     C|     9|   OK|
|     D|     9|   OK|
|     E|    12| Good|
|     F|    14| Good|
|     G|    15|Great|
|     H|    17|Great|
|     I|    19|Great|
|     J|    22|Great|
+------+------+-----+

The case statement looked at the value in the points column and returned:

  • Bad if the value in the points column was less than 9
  • OK if the value in the points column was at least 9 but less than 12
  • Good if the value in the points column was at least 12 but less than 15
  • Great if none of the previous conditions were true

Note: We used three conditions in this example, but you can chain together as many when() calls as you'd like to add more conditions to your own case statement.
