
PySpark – How to Update Nested Columns?

Last Updated : 30 Jan, 2023

In this article, we are going to learn how to update nested columns using PySpark.

PySpark is the Python interface for Apache Spark. A PySpark data frame can contain nested columns as well as flat ones, and you can update the values inside a nested column, either unconditionally or according to a specified condition. This article explains what nested columns are and then walks through updating them step by step.

What are nested columns?

Columns that can be further divided into sub-columns are known as nested columns. In PySpark, a nested column is defined with StructType, and each of its sub-columns (fields) can have any data type, such as IntegerType or StringType.

For example, Full Name can be divided into First Name, Middle Name, and Last Name. Here, Full Name would be a StructType column, while First Name, Middle Name, and Last Name would be StringType fields inside it.

Stepwise Implementation:

Step 1: First, import the required classes and functions: SparkSession, StructType, StructField, StringType, IntegerType, col, lit, and when. SparkSession is used to create the session, StructType defines the structure (schema) of the data frame, and StructField defines each column in it. StringType and IntegerType represent string and integer values respectively. col returns a column based on the given column name, lit creates a column holding a literal (constant) value, and when evaluates a condition, with otherwise supplying the value to use when the condition is not met.

from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import col, lit, when
from pyspark.sql import SparkSession

Step 2: Now, create a Spark session using the getOrCreate() function.

spark_session = SparkSession.builder.getOrCreate()

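Step 3: Then, define the data set as a list. In each row, the nested values are grouped in an inner tuple, in the same order as the fields of the struct.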

data_set = [((nested_values_1), column_value_1),
            ((nested_values_2), column_value_2),
            ((nested_values_3), column_value_3)]

Step 4: Next, define the structure (schema) using StructType and StructField.

schema = StructType([
    StructField('column_1', StructType([
        StructField('nested_column_1', column_type(), True),
        StructField('nested_column_2', column_type(), True),
        StructField('nested_column_3', column_type(), True)])),
    StructField('column_2', column_type(), True)])

Step 5: Further, create a PySpark data frame from the data set using the specified schema.

df = spark_session.createDataFrame(data = data_set, schema = schema)

Step 6: Then, update the nested field by calling withField on the struct column, passing the name of the nested field and the replacement value (here a constant created with lit). Column.withField is available from Spark 3.1 onward.

updated_df = df.withColumn("column_1",
                           col("column_1").withField("nested_column_1",
                                                     lit("replace_value")))

Step 7: Finally, display the updated data frame.

updated_df.show()

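Dataset used in the examples below: both examples build the same four-row data frame (defined in their code), with a nested Date_Of_Birth column made up of the Year, Month, and Date fields, and a flat Age column.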

Example 1:

In this example, we define the schema and data set and create the PySpark data frame from them. We then update the nested field ‘Date’: if ‘Date’ equals 2, it is replaced with 24; otherwise, the existing value is kept.

# PySpark program to update nested columns
  
# Import the libraries SparkSession, StructType, StructField,
# StringType, IntegerType, col, lit, when
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import col, lit, when
from pyspark.sql import SparkSession
  
# Create a spark session using getOrCreate() function
spark_session = SparkSession.builder.getOrCreate()
  
# Define the data set
data_set = [((2000, 21, 2), 18), ((1998, 14, 6), 24),
            ((1998, 1, 11), 18), ((2006, 30, 3), 16)]
  
# Define the structure for the data frame
schema = StructType([StructField('Date_Of_Birth',
         StructType([StructField('Year', IntegerType(), True),
                StructField('Month', IntegerType(), True),
                StructField('Date', IntegerType(), True)])),
                  StructField('Age', IntegerType(), True)])
  
# Create the PySpark data frame using the createDataFrame function
df = spark_session.createDataFrame(data=data_set, schema=schema)
  
# Update the nested field 'Date' with withField: when the
# condition is met use lit(24), otherwise keep the current value
updated_df = df.withColumn("Date_Of_Birth",
              col("Date_Of_Birth").withField("Date", when(
              col("Date_Of_Birth.Date") == 2,
              lit(24)).otherwise(col("Date_Of_Birth.Date"))))
  
# Display the updated data frame
updated_df.show()


Output:

+--------------+---+
| Date_Of_Birth|Age|
+--------------+---+
|{2000, 21, 24}| 18|
| {1998, 14, 6}| 24|
| {1998, 1, 11}| 18|
| {2006, 30, 3}| 16|
+--------------+---+

Example 2:

In this example, we again define the schema and data set and create the PySpark data frame from them. This time, the nested field ‘Year’ is updated based on a different, top-level column: if ‘Age’ equals 18, ‘Year’ is replaced with 2004; otherwise, the existing value is kept.

# PySpark program to update nested columns
  
# Import the libraries SparkSession, StructType,
# StructField, StringType, IntegerType, col, lit, when
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import col, lit, when
from pyspark.sql import SparkSession
  
# Create a spark session using getOrCreate() function
spark_session = SparkSession.builder.getOrCreate()
  
# Define the data set
data_set = [((2000, 21, 2), 18),
            ((1998, 14, 6), 24),
            ((1998, 1, 11), 18),
            ((2006, 30, 3), 16)]
  
# Define the structure for the data frame
schema = StructType([StructField('Date_Of_Birth',
            StructType([StructField('Year', IntegerType(), True),
            StructField('Month', IntegerType(), True),
            StructField('Date', IntegerType(), True) ])),
            StructField('Age', IntegerType(), True)])
  
# Create the PySpark data frame using the createDataFrame function
df = spark_session.createDataFrame(data = data_set,
                                   schema = schema)
  
# Update the nested field 'Year' with withField: when Age is 18
# use lit(2004), otherwise keep the current value
updated_df = df.withColumn("Date_Of_Birth",
                  col("Date_Of_Birth").withField("Year",
                  when(col("Age") == 18,
                  lit(2004)).otherwise(
                  col("Date_Of_Birth.Year"))))
  
# Display the updated data frame
updated_df.show()


Output:

+-------------+---+
|Date_Of_Birth|Age|
+-------------+---+
|{2004, 21, 2}| 18|
|{1998, 14, 6}| 24|
|{2004, 1, 11}| 18|
|{2006, 30, 3}| 16|
+-------------+---+

