Open In App

Defining DataFrame Schema with StructField and StructType

Last Updated : 26 Dec, 2022
Improve
Improve
Like Article
Like
Save
Share
Report

In this article, we will learn how to define DataFrame Schema with StructField and StructType. 

  • StructType and StructField are used to define the schema of a Dataframe, or a part of it. Together they specify the name, datatype, and nullable flag of each column.
  • A StructType object is a collection of StructField objects. It is a built-in datatype that holds a list of StructField.

Syntax: 

  • pyspark.sql.types.StructType(fields=None)
  • pyspark.sql.types.StructField(name, datatype,nullable=True)

Parameter:

  • fields – List of StructField.
  • name – Name of the column.
  • datatype – type of the data, e.g. Integer, String, Float, etc.
  • nullable – whether the field is allowed to contain NULL/None values or not.

For defining schema we have to use the StructType() object in which we have to define or pass the StructField() which contains the name of the column, datatype of the column, and the nullable flag. We can write:-

schema = StructType([StructField(column_name1,datatype(),nullable_flag),
            StructField(column_name2,datatype(),nullable_flag),
            StructField(column_name3,datatype(),nullable_flag)
            ])

Example 1: Defining DataFrame with schema with StructType and StructField.

Python




# importing necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType, FloatType
 
# function to create SparkSession
def create_session():
    """Return a SparkSession for the "Product_mart.com" application.

    The builder is configured with the "local" master and the session is
    obtained through getOrCreate().
    """
    builder = SparkSession.builder.master("local")
    builder = builder.appName("Product_mart.com")
    return builder.getOrCreate()
 
# function to create dataframe
def create_df(spark, data, schema):
    """Build a dataframe from ``data`` described by ``schema``.

    Args:
        spark: session object whose ``createDataFrame`` method is called.
        data: rows passed to ``createDataFrame`` as its first argument.
        schema: schema passed to ``createDataFrame`` as its second argument.

    Returns:
        The result of ``spark.createDataFrame(data, schema)``.
    """
    return spark.createDataFrame(data, schema)
 
 
if __name__ == "__main__":

    # obtain the SparkSession used by the rest of the script
    spark = create_session()

    # one tuple per product: (name, id, rating, price)
    input_data = [
        ("Refrigerator", 112345, 4.0, 12499),
        ("LED TV", 114567, 4.2, 49999),
        ("Washing Machine", 113465, 3.9, 69999),
        ("T-shirt", 124378, 4.1, 1999),
        ("Jeans", 126754, 3.7, 3999),
        ("Running Shoes", 134565, 4.7, 1499),
        ("Face Mask", 145234, 4.6, 999),
    ]

    # describe every column with its own StructField:
    # (column name, datatype, nullable flag)
    name_field = StructField("Product Name", StringType(), True)
    id_field = StructField("Product ID", LongType(), True)
    rating_field = StructField("Rating", FloatType(), True)
    price_field = StructField("Product Price", IntegerType(), True)

    # collect the fields, in column order, into the StructType schema
    schm = StructType([name_field, id_field, rating_field, price_field])

    # build the dataframe from the rows and the schema
    df = create_df(spark, input_data, schm)

    # visualizing the dataframe and its schema
    df.printSchema()
    df.show()


Output:

In the above code, we set the nullable flag to True. With nullable=True, the Dataframe is still created when a field value is NULL/None, and that value is kept as null in the Dataframe. 

Example 2: Defining Dataframe schema with nested StructType.

Python




# importing necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType, FloatType
 
# function to create SparkSession
def create_session():
    """Create the SparkSession used by this example.

    Uses the "local" master and the application name "Product_mart.com";
    the session itself comes from getOrCreate().
    """
    return (
        SparkSession.builder
        .master("local")
        .appName("Product_mart.com")
        .getOrCreate()
    )
 
# function to create dataframe
def create_df(spark, data, schema):
    """Return ``spark.createDataFrame(data, schema)``.

    ``spark`` is the session object, ``data`` the rows and ``schema`` the
    schema; both are forwarded positionally and unchanged.
    """
    dataframe = spark.createDataFrame(data, schema)
    return dataframe
 
 
if __name__ == "__main__":

    # obtain the SparkSession used by the rest of the script
    spark = create_session()

    # each row is ((name, id), rating, price); the inner tuple supplies
    # the values of the nested 'Product' struct
    input_data = [
        (("Refrigerator", 112345), 4.0, 12499),
        (("LED TV", 114567), 4.2, 49999),
        (("Washing Machine", 113465), 3.9, 69999),
        (("T-shirt", 124378), 4.1, 1999),
        (("Jeans", 126754), 3.7, 3999),
        (("Running Shoes", 134565), 4.7, 1499),
        (("Face Mask", 145234), 4.6, 999),
    ]

    # inner StructType: the fields that live inside the 'Product' column
    product_type = StructType([
        StructField('Product Name', StringType(), True),
        StructField('Product ID', LongType(), True),
    ])

    # outer StructType: 'Product' uses the inner StructType as its datatype
    # (no nullable flag is passed for it, exactly as in the nested form)
    schm = StructType([
        StructField('Product', product_type),
        StructField('Rating', FloatType(), True),
        StructField('Price', IntegerType(), True),
    ])

    # build the dataframe from the rows and the nested schema
    df = create_df(spark, input_data, schm)

    # visualizing the dataframe and its schema
    df.printSchema()
    df.show(truncate=False)


Output:

Example 3: Changing Structure of Dataframe and adding new column Using PySpark Column Class.

Python




# importing necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.functions import col, struct, when
from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType, FloatType
 
# function to create SparkSession
def create_session():
    """Return the SparkSession for the "Product_mart.com" example app.

    The master is set to "local" and the session is obtained with
    getOrCreate().
    """
    session_builder = SparkSession.builder
    session_builder = session_builder.master("local")
    session_builder = session_builder.appName("Product_mart.com")
    return session_builder.getOrCreate()
 
# function to create dataframe
def create_df(spark, data, schema):
    """Create a dataframe by delegating to ``spark.createDataFrame``.

    Args:
        spark: session object providing ``createDataFrame``.
        data: rows for the dataframe.
        schema: schema describing the rows.

    Returns:
        The object returned by ``spark.createDataFrame(data, schema)``.
    """
    return spark.createDataFrame(data, schema)
 
 
if __name__ == "__main__":

    # calling function to create SparkSession
    spark = create_session()

    # one tuple per product: (name, id, rating, price)
    input_data = [("Refrigerator", 112345, 4.0, 12499),
                  ("LED TV", 114567, 4.2, 49999),
                  ("Washing Machine", 113465, 3.9, 69999),
                  ("T-shirt", 124378, 4.1, 1999),
                  ("Jeans", 126754, 3.7, 3999),
                  ("Running Shoes", 134565, 4.7, 1499),
                  ("Face Mask", 145234, 4.6, 999)]

    # defining a flat (non-nested) schema for the dataframe;
    # the nesting is introduced below with struct()
    schm = StructType([
        StructField("Product Name", StringType(), True),
        StructField("Product ID", LongType(), True),
        StructField("Rating", FloatType(), True),
        StructField("Product Price", IntegerType(), True)])

    # calling function to create dataframe
    df = create_df(spark, input_data, schm)

    # copying the columns into the new struct column
    # "Product" (each one under a shorter alias)
    new_df = df.withColumn("Product",
                           struct(col("Product Name").alias("Name"),
                                  col("Product ID").alias("ID"),
                                  col("Rating").alias("Rating"),
                                  col("Product Price").alias("Price")))

    # adding the new column "Product Range" according to the price.
    # "Product Price" is already declared as IntegerType in the schema,
    # so it is compared directly without an extra cast. The conditions
    # are checked in order: < 1000 -> Low, < 7000 -> Medium, else High.
    price = col("Product Price")
    new_df = new_df.withColumn("Product Range",
                               when(price < 1000, "Low")
                               .when(price < 7000, "Medium")
                               .otherwise("High"))

    # dropping the original columns as all their values
    # are copied in the Product column
    new_df = new_df.drop("Product Name", "Product ID",
                         "Rating", "Product Price")

    # visualizing the dataframe and its schema
    new_df.printSchema()
    new_df.show(truncate=False)


Output:

  • In the above example, we are changing the structure of the Dataframe using struct() function and copy the column into the new struct ‘Product’ and creating the Product column using withColumn() function.
  • The ‘Product Name’, ‘Product ID’, ‘Rating’ and ‘Product Price’ columns are copied into the new struct ‘Product’ under the aliases ‘Name’, ‘ID’, ‘Rating’ and ‘Price’.
  • We are adding the new column ‘Product Range’ using the withColumn() function, according to a condition that splits the products into three categories, i.e. Low, Medium, and High. If ‘Product Price’ is less than 1000 then the product falls in the Low category; otherwise, if ‘Product Price’ is less than 7000 then the product falls in the Medium category; otherwise the product falls in the High category.
  • After creating the new struct ‘Product’ and adding the new column ‘Product Range’, we drop the ‘Product Name’, ‘Product ID’, ‘Rating’ and ‘Product Price’ columns using the drop() function. Then we print the schema of the restructured Dataframe with the added columns.

Example 4: Defining Dataframe schema using the JSON format and StructType().

Python




# importing necessary libraries
from pyspark.sql import SparkSession
import pyspark.sql.types as T
import json
 
# function to create SparkSession
def create_session():
    """Build the example's SparkSession.

    Configured with master "local" and application name
    "Product_mart.com"; getOrCreate() supplies the session.
    """
    configured = (
        SparkSession.builder
        .master("local")
        .appName("Product_mart.com")
    )
    return configured.getOrCreate()
 
# function to create dataframe
def create_df(spark, data, schema):
    """Return the dataframe built by ``spark.createDataFrame``.

    ``data`` and ``schema`` are forwarded positionally, in that order,
    without modification.
    """
    built = spark.createDataFrame(data, schema)
    return built
 
 
if __name__ == "__main__":

    # calling function to create SparkSession
    spark = create_session()

    # one tuple per product: (name, rating)
    input_data = [("Refrigerator", 4.0),
                  ("LED TV", 4.2),
                  ("Washing Machine", 3.9),
                  ("T-shirt", 4.1)
                  ]

    # defining schema for the dataframe with
    # StructType and StructField
    schm = T.StructType([
        T.StructField("Product Name", T.StringType(), True),
        T.StructField("Rating", T.FloatType(), True)
    ])

    # calling function to create dataframe
    df = create_df(spark, input_data, schm)

    # visualizing the dataframe and its schema
    print("Original Dataframe:-")
    df.printSchema()
    df.show()

    print("-------------------------------------------")
    print("Schema in json format:-")

    # storing the schema as a JSON string using
    # the schema.json() function
    schema_json = df.schema.json()
    print(schema_json)

    # rebuilding the schema from its JSON form: json.loads() parses the
    # string and StructType.fromJson() turns the parsed value into a schema.
    # The types module is imported as T, so StructType must be reached
    # through that alias (a bare StructType is not defined in this script).
    schm1 = T.StructType.fromJson(json.loads(schema_json))

    # creating dataframe using the schema rebuilt from JSON
    json_df = spark.createDataFrame(
        spark.sparkContext.parallelize(input_data), schm1)
    print("-------------------------------------------")
    print("Dataframe using json schema:-")

    # printing the schema of the dataframe created from the
    # JSON format schema and showing its rows
    json_df.printSchema()
    json_df.show()


Output:

Note: You can also store the schema in JSON format in a file and use that file to define the schema. The code stays the same as above, except that the file's contents are read and parsed (for example with json.load() on the opened file) instead of passing a variable to json.loads(). In the above example, the schema in JSON format is stored in a variable, and we use that variable to define the schema.

Example 5: Defining Dataframe schema using StructType() with ArrayType() and MapType().

Python




# importing necessary libraries
from pyspark.sql import SparkSession
import pyspark.sql.types as T
 
# function to create SparkSession
def create_session():
    """Return a SparkSession named "Product_mart.com" on the "local" master.

    The session is obtained from the configured builder via getOrCreate().
    """
    local_builder = SparkSession.builder.master("local")
    named_builder = local_builder.appName("Product_mart.com")
    return named_builder.getOrCreate()
 
# function to create dataframe
def create_df(spark, data, schema):
    """Create a dataframe from rows and a schema.

    Args:
        spark: session object exposing ``createDataFrame``.
        data: rows to load.
        schema: schema applied to the rows.

    Returns:
        The value returned by ``spark.createDataFrame(data, schema)``.
    """
    return spark.createDataFrame(data, schema)
 
 
if __name__ == "__main__":

    # calling function to create SparkSession
    spark = create_session()

    # Data containing the Array and Map- key,value pair:
    # (first name, last name, list of subjects, dict of subject pairs)
    input_data = [
        ("Alex", 'Buttler', ["Mathematics", "Computer Science"],
         {"Mathematics": 'Physics', "Chemistry": "Biology"}),
        # two separate subjects, like every other row (previously a
        # single "Chemistry, Biology" string caused by a misplaced quote)
        ("Sam", "Samson", ["Chemistry", "Biology"],
         {"Chemistry": 'Physics', "Mathematics": "Biology"}),
        ("Rossi", "Bryant", ["English", "Geography"],
         {"History": 'Geography', "Chemistry": "Biology"}),
        ("Sidz", "Murraz", ["History", "Environmental Science"],
         {"English": 'Environmental Science', "Chemistry": "Mathematics"}),
        ("Robert", "Cox", ["Physics", "English"],
         {"Computer Science": 'Environmental Science', "Chemistry": "Geography"})
    ]

    # defining schema with ArrayType and MapType()
    # using StructType() and StructField().
    # The types module is imported as T, so every type is reached through
    # that alias (the bare names are not defined in this script).
    array_schm = T.StructType([
        T.StructField('Firstname', T.StringType(), True),
        T.StructField('Lastname', T.StringType(), True),
        # array column whose elements are strings
        T.StructField('Subject', T.ArrayType(T.StringType()), True),
        # map column with string keys and string values
        T.StructField('Subject Combinations', T.MapType(
            T.StringType(), T.StringType()), True)
    ])

    # calling function for creating the dataframe
    df = create_df(spark, input_data, array_schm)

    # printing schema of df and showing dataframe
    df.printSchema()
    df.show(truncate=False)


Output:



Like Article
Suggest improvement
Share your thoughts in the comments

Similar Reads