
How to re-partition a PySpark DataFrame in Python

Last Updated : 04 Dec, 2022

Are you a data science or machine learning enthusiast who likes to play with data? Have you ever needed to repartition a PySpark DataFrame and been unsure how to do it? Don’t worry! In this article, we will discuss how to repartition a PySpark DataFrame in Python.

Modules Required:

  • PySpark: the Python API for Apache Spark, which lets you run Python applications on Spark. It can be installed with the following shell command:
pip install pyspark

Stepwise Implementation:

Step 1: First of all, import the required library, i.e. SparkSession, which is used to create the session.

from pyspark.sql import SparkSession

Step 2: Now, create a Spark session using the getOrCreate() method.

spark_session = SparkSession.builder.getOrCreate()

Step 3: Then, read the CSV file and display it to check that it was loaded correctly.

data_frame = spark_session.read.csv('#Path of CSV file', sep=',', inferSchema=True, header=True)
data_frame.show()

Step 4: Next, get the number of partitions of the DataFrame's underlying RDD before repartitioning, using the getNumPartitions() method.

print(data_frame.rdd.getNumPartitions())

Step 5: Then, repartition the data using the select() and repartition() methods. select() keeps only the listed columns (it does not partition by them), while repartition() takes the number of partitions to split the rows into.

data_frame_partition = data_frame.select(<column names to keep>).repartition(<number of partitions>)

Step 6: Finally, get the number of partitions after the repartition, again using getNumPartitions(), to confirm that the repartition worked.

print(data_frame_partition.rdd.getNumPartitions())

In this example, we read the california_housing_train.csv file and print its current number of partitions. We then select the longitude and latitude columns, repartition that data into 4 partitions, and print the partition count again to check that the repartition worked.

Python




# Python program to repartition
# a PySpark DataFrame

# Import the SparkSession library
from pyspark.sql import SparkSession

# Create a spark session using getOrCreate() function
spark_session = SparkSession.builder.getOrCreate()

# Read the CSV file
data_frame = spark_session.read.csv('california_housing_train.csv',
                                    sep=',', inferSchema=True,
                                    header=True)

# Display the first row of the CSV file that was read
print(data_frame.head())

# Get number of partitions in data frame using getNumPartitions function
print("Before repartition", data_frame.rdd.getNumPartitions())

# Keep only the longitude and latitude columns,
# then shuffle the rows into 4 partitions
data_frame_partition = data_frame.select(data_frame.longitude,
                                         data_frame.latitude).repartition(4)

# Get number of partitions after the repartition
print("After repartition", data_frame_partition.rdd.getNumPartitions())


Output:

Row(longitude=-114.31, latitude=34.19, housing_median_age=15.0, 
total_rooms=5612.0, total_bedrooms=1283.0, population=1015.0, households=472.0, 
median_income=1.4936, median_house_value=66900.0)

Before repartition 1
After repartition 4

