Can you provide an example of using a case statement in PySpark?

A case statement in PySpark is a conditional expression that returns different values depending on which condition is met. It is written as a chain of “when” clauses followed by an “otherwise” clause: each “when” clause checks a condition and returns the corresponding value if that condition is true, while the “otherwise” clause supplies a default value when none of the “when” conditions are satisfied. A common use of a case statement in PySpark is to categorize rows based on the value in a specific column, assigning each category a different label. This allows for efficient and concise conditional logic in data transformation and manipulation tasks.
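
For instance, a minimal sketch of this pattern (assuming a DataFrame named df with a hypothetical string column called grade) might look like this:

from pyspark.sql.functions import when, col

#assign a label based on the value in the hypothetical 'grade' column
df = df.withColumn('label', when(col('grade') == 'A', 'Excellent')
                            .when(col('grade') == 'B', 'Good')
                            .otherwise('Needs Improvement'))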

How to Use a Case Statement in PySpark (With Example)

A case statement evaluates a series of conditions in order and returns a value as soon as the first condition is met.

The easiest way to implement a case statement in a PySpark DataFrame is by using the following syntax:

from pyspark.sql.functions import when

df.withColumn('class', when(df.points<9, 'Bad')
                      .when(df.points<12, 'OK')
                      .when(df.points<15, 'Good')
                      .otherwise('Great')).show()

This particular example adds a new column called class to the DataFrame, which takes on the following values:

  • Bad if the value in the points column is less than 9
  • OK if the value in the points column is less than 12
  • Good if the value in the points column is less than 15
  • Great if none of the previous conditions are true

The following example shows how to use this function in practice.

Example: How to Use a Case Statement in PySpark

Suppose we have the following PySpark DataFrame:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [['A', 6],
        ['B', 8], 
        ['C', 9], 
        ['D', 9], 
        ['E', 12], 
        ['F', 14],
        ['G', 15],
        ['H', 17],
        ['I', 19],
        ['J', 22]] 
  
#define column names
columns = ['player', 'points'] 
  
#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view dataframe
df.show()

+------+------+
|player|points|
+------+------+
|     A|     6|
|     B|     8|
|     C|     9|
|     D|     9|
|     E|    12|
|     F|    14|
|     G|    15|
|     H|    17|
|     I|    19|
|     J|    22|
+------+------+

We can use the following syntax to write a case statement that creates a new column called class whose values are determined by the values in the points column:

from pyspark.sql.functions import when

df.withColumn('class', when(df.points<9, 'Bad')
                      .when(df.points<12, 'OK')
                      .when(df.points<15, 'Good')
                      .otherwise('Great')).show()

+------+------+-----+
|player|points|class|
+------+------+-----+
|     A|     6|  Bad|
|     B|     8|  Bad|
|     C|     9|   OK|
|     D|     9|   OK|
|     E|    12| Good|
|     F|    14| Good|
|     G|    15|Great|
|     H|    17|Great|
|     I|    19|Great|
|     J|    22|Great|
+------+------+-----+

The case statement looked at the value in the points column and returned:

  • Bad if the value in the points column was less than 9
  • OK if the value in the points column was less than 12
  • Good if the value in the points column was less than 15
  • Great if none of the previous conditions were true
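
The same logic can also be written as a literal SQL CASE expression using the expr() function. The following sketch should produce the same class column for the DataFrame above:

from pyspark.sql.functions import expr

#equivalent case statement written as a SQL CASE expression
df.withColumn('class', expr("CASE WHEN points < 9 THEN 'Bad' "
                            "WHEN points < 12 THEN 'OK' "
                            "WHEN points < 15 THEN 'Good' "
                            "ELSE 'Great' END")).show()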

Note: We chose to use three conditions in this particular example, but you can chain together as many when() statements as you’d like to include even more conditions in your own case statement.
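
For example, one way to build such a chain programmatically from a list of cutoffs is sketched below (the thresholds here are simply the ones used in this example):

from pyspark.sql.functions import when

#(threshold, label) pairs, checked in order
cutoffs = [(9, 'Bad'), (12, 'OK'), (15, 'Good')]

#build the chain of when() calls one cutoff at a time
case_col = when(df.points < cutoffs[0][0], cutoffs[0][1])
for threshold, label in cutoffs[1:]:
    case_col = case_col.when(df.points < threshold, label)

df.withColumn('class', case_col.otherwise('Great')).show()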
