How to See Record Count Per Partition in a pySpark DataFrame

PySpark is the Python API for Apache Spark. When PySpark reads a file into a DataFrame, it splits the data into partitions; the initial number of partitions depends on factors such as the file size and the number of available cores. The user can repartition the data into as many partitions as needed. After partitioning, it is often useful to know how many records each partition holds, and this can be found using the spark_partition_id function of the PySpark module.

Modules Required:

PySpark: The Python API for Apache Spark. This module can be installed through the following command:

pip install pyspark

Stepwise Implementation:

Step 1: First of all, import SparkSession and spark_partition_id. SparkSession is the class used to create the session, while spark_partition_id is a function that returns the ID of the partition each row belongs to; grouping by that ID gives the record count per partition.

from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id

Step 2: Now, create a spark session using the getOrCreate function.

spark_session = SparkSession.builder.getOrCreate()

Step 3: Then, read the CSV file and display it to see if it is correctly uploaded.

data_frame = spark_session.read.csv('#Path of CSV file', sep=',', inferSchema=True, header=True)
data_frame.show()

Step 4: Now, get the number of partitions of the data frame's underlying RDD using the getNumPartitions function.

print(data_frame.rdd.getNumPartitions())

Step 5: Next, get the record count per partition by adding a partitionId column with the spark_partition_id function, grouping by that column, and counting the rows in each group.

data_frame.withColumn("partitionId",spark_partition_id()).groupBy("partitionId").count().show()

Step 6: Then, repartition the data using the select and repartition functions. The select function takes the names of the columns to keep, while the repartition function takes the number of partitions to create.

data_frame_partition = data_frame.select(<column names to keep>).repartition(<number of partitions>)

Step 7: After repartitioning, check the number of partitions again using the getNumPartitions function to confirm that the repartition took effect.

print(data_frame_partition.rdd.getNumPartitions())

Step 8: Finally, get the record count per partition using the spark_partition_id function after the repartition of data.

data_frame_partition.withColumn("partitionId",spark_partition_id()).groupBy("partitionId").count().show()

Example 1:

In this example, we read a CSV file containing a 5×5 dataset and obtain the number of partitions as well as the record count per partition using the spark_partition_id function. We then repartition the data and again obtain the number of partitions and the record count per partition for the repartitioned data.

# Python program to see Record Count
# Per Partition in a pySpark DataFrame
  
# Import SparkSession and the spark_partition_id function
from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id
  
# Create a spark session using getOrCreate() function
spark_session = SparkSession.builder.getOrCreate()
  
# Read the CSV file
data_frame = spark_session.read.csv(
    '/content/student_data.csv', sep=',', inferSchema=True, header=True)
  
# Get number of partitions in data frame using getNumPartitions function
print(data_frame.rdd.getNumPartitions())
  
# Get record count per partition using spark_partition_id function
data_frame.withColumn("partitionId", spark_partition_id()
                      ).groupBy("partitionId").count().show()
  
# Select the name and age columns and repartition the data into 2 partitions
data_frame_partition = data_frame.select(
    data_frame.name, data_frame.age).repartition(2)
  
# Get number of partitions in data frame using getNumPartitions function
print(data_frame_partition.rdd.getNumPartitions())
  
# Get record count per partition using spark_partition_id function
data_frame_partition.withColumn("partitionId", spark_partition_id()).groupBy(
    "partitionId").count().show()


Output:

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|    5|
+-----------+-----+

2

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|    3|
|          1|    2|
+-----------+-----+

Example 2:

In this example, we read a larger CSV file and obtain the number of partitions as well as the record count per partition using the spark_partition_id function.

# Python program to see Record Count
# Per Partition in a pySpark DataFrame
  
# Import SparkSession and the spark_partition_id function
from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id
  
# Create a spark session using getOrCreate() function
spark_session = SparkSession.builder.getOrCreate()
  
# Read the CSV file
data_frame = spark_session.read.csv(
    '/content/sample_data/california_housing_train.csv',
    sep=',', inferSchema=True, header=True)
  
# Get number of partitions in data frame using getNumPartitions function
print(data_frame.rdd.getNumPartitions())
  
# Get record count per partition using spark_partition_id function
data_frame.withColumn("partitionId", spark_partition_id()
                      ).groupBy("partitionId").count().show()


Output:

1
+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|17000|
+-----------+-----+


Last Updated : 04 Dec, 2022