
Split Spark DataFrame based on condition in Python

In this article, we are going to learn how to split data frames based on conditions using PySpark in Python.

Spark DataFrames are a powerful tool for working with large datasets in Apache Spark. They allow us to manipulate and analyze data in a structured way, using SQL-like operations. Sometimes, we may want to split a Spark DataFrame based on a specific condition. For example, we may want to split a DataFrame into two separate DataFrames based on whether a column value is above or below a certain threshold.



Why split a data frame based on a condition? 

There are a few common reasons to split a data frame based on a condition:

Splitting the PySpark data frame using the filter() method

The filter() method returns a new data frame containing only the rows that satisfy the condition passed to filter() as its argument. The original data frame is left unchanged.



Syntax:

df.filter(condition)

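  • df is the DataFrame to filter, and condition is a boolean Column expression (for example df['age'] > 30) or a SQL expression string (for example "age > 30"). Rows for which the condition evaluates to true are kept; rows where it evaluates to false or null are dropped.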

Problem statement: Given a CSV file containing information about people, such as their name, age, and gender, the task is to split the data into two data frames based on the gender of the person. The first data frame should contain the rows where the gender is male, and the second data frame should contain rows where the gender is female. Below is the stepwise implementation to perform this task:

Step 1: The first step is to create a SparkSession, which is the entry point to using Spark functionality. We give it the name “Split DataFrame” for reference.

Step 2: Next, we use the spark.read.csv() method to load the data from the “number.csv” file into a data frame. We specify that the file has a header row and that we want Spark to infer the schema of the data.

Step 3: We then use the filter() method on the data frame to split it into two new data frames based on a certain condition. In this case, we use the condition df['gender'] == 'Male' to filter the data frame and create a new data frame called males_df containing only rows with a gender of 'Male'. Similarly, we use the condition df['gender'] == 'Female' to filter the data frame and create a new data frame called females_df containing only rows with a gender of 'Female'.
Step 4: Finally, we use the show() method to print the contents of the males_df and females_df data frames.

Dataset: number.csv 




# Import required modules
from pyspark.sql import SparkSession
  
# Create a SparkSession
spark = SparkSession.builder.appName("Split DataFrame").getOrCreate()
  
# Load the data into a DataFrame
df = spark.read.csv("number.csv",
                    header=True,
                    inferSchema=True)
df.show()
  
# Split the DataFrame into two 
# DataFrames based on a condition
males_df = df.filter(df['gender'] == 'Male')
females_df = df.filter(df['gender'] == 'Female')
  
# Print the dataframes 
males_df.show()
females_df.show()

Output before split:

 

Output after split: 

 

Alternatively, we can use the where() method, which is an alias for filter(), for example:

males_df = df.where(df['gender'] == 'Male')
females_df = df.where(df['gender'] == 'Female')

Handling the Split DataFrames

Once we have split the data frame, we can perform further operations on the resulting data frames, such as aggregating the data, joining with other tables, or saving the data to a new file. Here is an example of how to use the count() method to get the number of rows in each of the split data frames:




# Import required modules
from pyspark.sql import SparkSession
  
# Create a SparkSession
spark = SparkSession.builder.appName("Split DataFrame").getOrCreate()
  
# Load the data into a DataFrame
df = spark.read.csv("number.csv",
                    header=True,
                    inferSchema=True)
  
# Split the DataFrame into two
# DataFrames based on a condition
males_df = df.filter(df['gender'] == 'Male')
females_df = df.filter(df['gender'] == 'Female')
  
# Print the data frames
males_df.show()
females_df.show()
  
# Print the count
print("Males:", males_df.count())
print("Females:", females_df.count())

Output:

 

