
Apply function to each row of Spark DataFrame

Last Updated : 08 Jun, 2023

Spark is an open-source, distributed computing system used for processing large data sets across a cluster of computers. It has become increasingly popular due to its ability to handle big data processing in real time. Spark's DataFrame API, which offers a practical and effective way to carry out data manipulation operations, is one of its key features. We frequently need to process each row of a Spark DataFrame individually. This can be helpful for a variety of applications, including data transformations, feature engineering, and data cleansing.

In this article, we will discuss how to apply a function to each row of a Spark DataFrame. This is a common operation that is required when performing data cleaning, data transformation, or data analysis tasks.

Required Modules :

pip install pyspark

Concepts Related To the Topic :

Before we dive into the steps for applying a function to each row of a Spark DataFrame, let’s briefly go over some of the key concepts involved.

  1. Spark DataFrame: A DataFrame is an immutable, distributed collection of data organized into named columns, providing a relational view of the data. It is similar to a table in a relational database or a data frame in R or Python. DataFrames can be created from structured data files, Hive tables, external databases, or RDDs.
  2. Row: A Row is a collection of named data items in a data frame. It is an immutable data structure that represents a single row of data in a data frame.
  3. map() function: The map() function is a higher-order function that takes a function as an argument and applies it to each element in a collection. It returns a new collection with the same number of elements as the original collection.  
  4. RDD: RDD stands for Resilient Distributed Datasets. It is the fundamental data structure in Spark that represents an immutable, distributed collection of objects. RDDs can be created from local data, external storage systems, or other RDDs.
  5. UDF: A User-Defined Function (UDF) is a function defined by the user to perform a specific task. In Spark, UDFs can be used to apply custom functions to the data in a DataFrame or RDD; a short sketch of a UDF in action follows this list.
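As a quick illustration of the UDF concept, here is a minimal sketch of defining a UDF and applying it to one column with withColumn. It assumes a local SparkSession; the names_df DataFrame and capitalize_name function are placeholder names used only for this sketch.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()

# A tiny placeholder DataFrame with one string column
names_df = spark.createDataFrame([("alice",), ("bob",)], ["name"])

# Define a UDF that capitalizes each value in the column it is applied to
@udf(StringType())
def capitalize_name(name):
    return name.capitalize()

# Apply the UDF column-wise and show the result
names_df.withColumn("capitalized", capitalize_name("name")).show()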

Steps Needed :

Now that we have a basic understanding of the concepts involved, let’s look at the steps for applying a function to each row of a Spark DataFrame.

  1. Define the function: The first step is to define the function that you want to apply to each row of the data frame. The function should take a single argument, which is a row of the DataFrame.
     
  2. Convert DataFrame to RDD: The next step is to convert the DataFrame to an RDD. This can be done using the rdd method of the DataFrame.
     
  3. Apply the function to each row: Once you have an RDD, you can use the map method to apply the function to each row of the RDD.
     
  4. Convert RDD back to DataFrame: After applying the function to each row of the RDD, you can convert the RDD back to a DataFrame using the toDF method.
     
  5. Display the results: Finally, you can display the results of the operation using the show method of the DataFrame. A minimal sketch of this workflow is shown just below the list; the numbered examples that follow fill it in with concrete functions.
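The following sketch strings the five steps together in their most generic form. The DataFrame contents, the process_row function, and the column names are placeholders chosen only to illustrate the workflow.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Placeholder DataFrame; real data would come from a file, table, etc.
df = spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])

# Step 1: define the function; it receives one Row and returns a tuple
def process_row(row):
    return (row["a"], row["b"], row["a"] * row["b"])

# Steps 2 and 3: convert the DataFrame to an RDD and map the function over it
mapped_rdd = df.rdd.map(process_row)

# Step 4: convert the RDD back to a DataFrame with explicit column names
result_df = mapped_rdd.toDF(["a", "b", "product"])

# Step 5: display the results
result_df.show()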

Examples :

Example 1 :

Add a new column to the DataFrame that is the sum of two existing columns.

Python3




# Import necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession (entry point for DataFrame operations)
spark = SparkSession.builder.getOrCreate()

# Create a sample DataFrame
data = [(1, 2), (3, 4), (5, 6)]
df = spark.createDataFrame(data, ["col1", "col2"])
  
# Define a function to add two columns
def add_columns(row):
    return (row[0], row[1], row[0] + row[1])
  
# Apply the function to each row of the DataFrame
new_rdd = df.rdd.map(add_columns)
  
# Convert the RDD back to a DataFrame
new_df = new_rdd.toDF(["col1", "col2", "sum"])
  
# Display the results
new_df.show()


Output :
 

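As a side note, this particular transformation does not require dropping down to the RDD at all: the same column can be built from a column expression, which avoids passing every row through a Python function. A minimal sketch, reusing the df created above (sum_df is just an illustrative name):

from pyspark.sql.functions import col

# Build the new column from an expression instead of mapping over rows
sum_df = df.withColumn("sum", col("col1") + col("col2"))
sum_df.show()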

Example 2 :

Suppose we have a DataFrame sales_data with columns date, salesperson, and sales. We want to apply a custom function to each row to calculate the commission earned by each salesperson based on their sales, using the following formula:

  1. If sales <= 1000, commission = 0.05 * sales
  2. If 1000 < sales <= 5000, commission = 0.1 * sales
  3. If sales > 5000, commission = 0.15 * sales

We use the @udf decorator to define two UDFs: one that calculates the commission from the sales column using the formula above, and one that extracts the month from the date column. The rows are then grouped by salesperson and month to get the total sales and total commission earned by each salesperson per month.

Python3




from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, FloatType
from pyspark.sql.session import SparkSession

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Define a Python function to calculate the
# commission for each salesperson in each order,
# following the formula described above
@udf(FloatType())
def calculate_commission(sales):
    if sales <= 1000:
        return 0.05 * sales
    elif sales <= 5000:
        return 0.1 * sales
    else:
        return 0.15 * sales
  
# Define a Python function to extract the month from the date
@udf(StringType())
def extract_month(date):
    return date.split('-')[1]
  
# Create a DataFrame from the sales data
sales_data = spark.createDataFrame([
    ("Alice", "2022-01-01", 5000),
    ("Alice", "2022-01-01", 7000),
    ("Alice", "2022-02-01", 5000),
    ("Bob", "2022-01-01", 2000),
    ("Bob", "2022-01-01", 4000),
    ("Bob", "2022-02-01", 2000)
], ["salesperson", "date", "sales"])
  
# Apply the UDFs to each row of the DataFrame and group
# by salesperson and month
sales_data = sales_data.withColumn('commission',
                                   calculate_commission('sales'))
sales_data = sales_data.withColumn('month',
                                   extract_month('date'))
sales_data = sales_data.groupBy('salesperson', 'month').agg(
    {'sales': 'sum', 'commission': 'sum'})

# Rename the aggregated columns to 'total_sales' and
# 'total_commission' and sort by salesperson and month
sales_data = sales_data.withColumnRenamed('sum(sales)', 'total_sales')
sales_data = sales_data.withColumnRenamed('sum(commission)', 'total_commission')
sales_data = sales_data.orderBy(['salesperson', 'month'])
  
# Show the resulting DataFrame
sales_data.show()


Output :


Example 3:

Suppose you have a large dataset containing information about movies, including their titles, release years, genres, and ratings. You want to perform some data cleaning and manipulation on this dataset before analyzing it further. Specifically, you want to:
1. Remove any movies that have a rating of less than 5.
2. Add a new column called “decade” that specifies the decade in which each movie was released.
3. Convert the values in the “genre” column to a list.
 

Python3




from pyspark.sql.session import SparkSession

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()
  
# Create a sample DataFrame
data = [("The Godfather", 1972,
         "Crime, Drama", 9.2),
        ("The Shawshank Redemption",
         1994, "Drama", 9.3),
        ("The Dark Knight", 2008,
         "Action, Crime, Drama", 9.0),
        ("Pulp Fiction", 1994,
         "Crime, Drama", 8.9),
        ("The Lord of the Rings: The Return of the King",
         2003, "Adventure, Drama, Fantasy", 8.9),
        ("Forrest Gump", 1994, "Drama, Romance", 8.8),
        ("Inception", 2010,
         "Action, Adventure, Sci-Fi", 8.8),
        ("The Lord of the Rings: The Fellowship of the Ring",
         2001, "Adventure, Drama, Fantasy", 8.8),
        ("The Lion King", 1994,
         "Animation, Adventure, Drama", 8.5),
        ("The Matrix", 1999, "Action, Sci-Fi", 8.7)]
  
df = spark.createDataFrame(data, ["title", "year", "genre", "rating"])
  
# Define a function to remove movies with ratings less than 5
def filter_low_ratings(row):
    return row["rating"] >= 5
  
# Define a function to extract the decade from the year
def extract_decade(row):
    year = row["year"]
    decade = str(year // 10 * 10) + "s"
    return row + (decade, )
  
# Define a function to convert the genre to a list
def convert_to_list(row):
    genre = row["genre"]
    genre_list = genre.split(", ")
    return row + (genre_list, )
  
# Apply the functions to each row of the DataFrame
filtered_df = df.rdd.filter(filter_low_ratings).toDF(df.columns)
decade_df = filtered_df.rdd.map(extract_decade).toDF(df.columns
                                                     + ["decade"])
list_df = decade_df.rdd.map(convert_to_list).toDF(df.columns
                                                  + ["decade", "genre_list"])
  
# Display the results
list_df.show()


Output :


Example 4:

This example shows the use of a UDF to apply the Caesar cipher to a string column.

The Caesar cipher is a simple encryption technique that replaces each letter in the original message with a letter a fixed number of positions down the alphabet. For example, with a shift of 3, A would be replaced by D, B would become E, and so on. To decode a message encoded with the Caesar cipher, the text is shifted back by the same amount. We can implement the shift as a UDF that takes a string and a shift value as inputs and returns the shifted text. Here's an example implementation:
 

Python3




from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, StructType
from pyspark.sql.types import StructField
from pyspark.sql.session import SparkSession

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Define Caesar Cipher UDF
@udf(StringType())
def caesar_cipher(text, shift):
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    shifted_alphabet = alphabet[int(shift):]+ \
        alphabet[:int(shift)]
    table = str.maketrans(alphabet, 
        shifted_alphabet)
    return text.translate(table)
  
# Define schema for DataFrame
schema = StructType([
    StructField('id', StringType(), True),
    StructField('text', StringType(), True),
    StructField('shift', StringType(), True)
])
  
# Create sample DataFrame
data = [
    ('1', 'hello', '3'),
    ('2', 'world', '5'),
    ('3', 'goodbye', '10')
]
  
df = spark.createDataFrame(data, schema)
  
# Apply Caesar Cipher UDF to DataFrame
df = df.withColumn('ciphered_text',
                   caesar_cipher(df['text'], df['shift']))
  
# Show results
df.show()


This code will create a sample data frame with 3 rows, each containing a text string and a shift value. The Caesar Cipher UDF will then be applied to the ‘text’ and ‘shift’ columns, and the resulting ciphered text will be stored in a new column called ‘ciphered_text’. Finally, the results will be printed using the show() method.
Output :
 

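Because the UDF only shifts letters forward, decoding is just a matter of applying the complementary shift. A minimal sketch, reusing the caesar_cipher UDF and the DataFrame built above (decoded_text is just an illustrative column name):

from pyspark.sql.functions import lit

# Shifting forward by (26 - shift) undoes the original shift;
# 'decoded_text' is only an illustrative column name
df = df.withColumn('decoded_text',
                   caesar_cipher(df['ciphered_text'],
                                 lit(26) - df['shift'].cast('int')))
df.show()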


