How to Split Data into Training & Test Sets in PySpark?

In PySpark, splitting data into training and test sets is a two-step process. First, use the randomSplit() method to create two DataFrames, one for training and one for testing. Then, fit your model on the training DataFrame and evaluate it on the test DataFrame. This ensures the model is evaluated on data it has not seen before and helps you detect overfitting.


Often when we fit models to datasets, we first split the dataset into a training set and a test set.

The easiest way to split a dataset into a training and test set in PySpark is to use the randomSplit function as follows:

train_df, test_df = df.randomSplit(weights=[0.7,0.3], seed=100)

The weights argument specifies the fraction of rows from the original DataFrame to place in the training and test set, respectively.

In this example, we chose to place 70% of the observations into the training set and 30% in the test set.

The seed argument is an integer used to ensure that the random split is the same each time you run the code.
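Note that randomSplit also accepts more than two weights, so the same call can produce a train/validation/test split. Here is a minimal sketch, assuming we want an illustrative 60/20/20 split; the variable names are just examples:

#split dataset into training, validation, and test sets (illustrative 60/20/20 split)
train_df, val_df, test_df = df.randomSplit(weights=[0.6, 0.2, 0.2], seed=100)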

The following example shows how to split a PySpark DataFrame into a training and test set in practice.

Example: Split Data into Training and Test Set in PySpark

First, let’s create the following PySpark DataFrame that contains information about hours spent studying, number of prep exams taken, and final exam score for various students at some university:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [[1, 1, 76],
        [2, 3, 78],
        [2, 3, 85],
        [4, 5, 88],
        [2, 2, 72],
        [1, 2, 69],
        [5, 1, 94],
        [4, 1, 94],
        [2, 0, 88],
        [4, 3, 92],
        [4, 4, 90],
        [3, 3, 75],
        [6, 2, 96],
        [5, 4, 90],
        [3, 4, 82],
        [4, 4, 85],
        [6, 5, 99],
        [2, 1, 83],
        [1, 0, 62],
        [2, 1, 76]]
  
#define column names
columns = ['hours', 'prep_exams', 'score'] 
  
#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view first five rows of dataframe
df.limit(5).show()

+-----+----------+-----+
|hours|prep_exams|score|
+-----+----------+-----+
|    1|         1|   76|
|    2|         3|   78|
|    2|         3|   85|
|    4|         5|   88|
|    2|         2|   72|
+-----+----------+-----+

Suppose we would like to fit a model to this dataset, using hours and prep_exams as the predictor variables and score as the response variable.

Before we do so, we may first want to randomly split the dataset so that 70% of the total rows are used for training and 30% are used for testing.

We can use the following syntax to do so:

#split dataset into training and test sets
train_df, test_df = df.randomSplit(weights=[0.7,0.3], seed=100)

We can then use the count() function to view the number of rows in each resulting dataset:

#view count of rows in train_df
print(train_df.count())

14

#view count of rows in test_df
print(test_df.count())

6

We can see that 14 of the 20 original rows (70%) were placed in the training set. Keep in mind that randomSplit assigns each row to a split probabilistically, so the actual proportions only approximately match the weights, especially for small datasets.

If we’d like, we can also view the first five rows of both the training and test sets:

#view first five rows of training set
train_df.limit(5).show()

+-----+----------+-----+
|hours|prep_exams|score|
+-----+----------+-----+
|    1|         1|   76|
|    2|         3|   78|
|    2|         3|   85|
|    4|         5|   88|
|    1|         2|   69|
+-----+----------+-----+

#view first five rows of test set
test_df.limit(5).show()

+-----+----------+-----+
|hours|prep_exams|score|
+-----+----------+-----+
|    2|         2|   72|
|    2|         0|   88|
|    4|         1|   94|
|    3|         4|   82|
|    4|         4|   85|
+-----+----------+-----+

We have successfully split the original dataset into a training and test set.

We can now proceed to fit whatever model we’d like to the training set and then test the performance of the model on the test set.
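For example, here is a minimal sketch of how we might fit a linear regression model to the training set and evaluate it on the test set using pyspark.ml. The column names match the DataFrame above, but the choice of model and evaluation metric (RMSE) is purely illustrative:

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

#combine the predictor columns into a single features vector
assembler = VectorAssembler(inputCols=['hours', 'prep_exams'], outputCol='features')
train_prepared = assembler.transform(train_df)
test_prepared = assembler.transform(test_df)

#fit a linear regression model on the training set
lr = LinearRegression(featuresCol='features', labelCol='score')
model = lr.fit(train_prepared)

#generate predictions for the test set and compute RMSE
predictions = model.transform(test_prepared)
evaluator = RegressionEvaluator(labelCol='score', predictionCol='prediction', metricName='rmse')
print(evaluator.evaluate(predictions))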

Note: You can find the complete documentation for the PySpark randomSplit function in the official PySpark API documentation.
