PySpark UDF of MapType
Last Updated: 12 Sep, 2023
Consider a PySpark DataFrame with a column of type MapType whose keys are strings and whose values represent different kinds of data (integers, strings, booleans, and so on). We need to operate on this column: filter it, transform its values, or extract specific keys from the map. PySpark lets you define user-defined functions (UDFs) to apply custom transformations to DataFrames. UDFs over primitive types are straightforward, but a MapType with mixed value types needs extra care, because a Spark map stores all of its values with a single value type. This tutorial walks through the steps to create a PySpark UDF for a MapType column with mixed values.
PySpark UDF and MapType: Functions and Syntax
The udf() function in pyspark.sql.functions turns a Python function into a user-defined function. It takes two parameters: the Python function and its return type (a Spark data type, StringType() by default).
Syntax of PySpark UDF
Syntax: udf(f, returnType)
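As a rough sketch of the udf(f, returnType) call pattern, the example below keeps the per-row logic in a plain Python function and wraps it only when it is needed. The names double_values and make_double_values_udf are hypothetical, and the map is assumed to have string keys and integer values stored as LongType (the type PySpark infers for Python ints).

```python
def double_values(counts):
    """Plain-Python UDF body: double every non-null value in a dict."""
    # Spark hands a null map to a Python UDF as None
    if counts is None:
        return None
    return {key: (value * 2 if value is not None else None) for key, value in counts.items()}


def make_double_values_udf():
    """Wrap double_values as a PySpark UDF (hypothetical helper)."""
    # Imported here so double_values stays testable without PySpark installed
    from pyspark.sql.functions import udf
    from pyspark.sql.types import LongType, MapType, StringType

    # returnType must describe what the function returns: map<string, bigint>
    return udf(double_values, MapType(StringType(), LongType()))
```

With a DataFrame like the fruit_counts examples in this article, the wrapped function would be applied as df.withColumn("doubled", make_double_values_udf()("fruit_counts")). Keeping the logic separate from udf() lets you check it with ordinary asserts, without starting a Spark session.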
A MapType column stores a dictionary-like collection of key-value pairs. All keys share one data type and all values share one data type, although the key type and the value type can differ (for example, string keys with integer values).
Syntax of PySpark MapType
Syntax: MapType(keyType, valueType, valueContainsNull=True)
Parameters:
- keyType: data type of the map's keys; keys can never be null.
- valueType: data type of the map's values; every value in the map has this type.
- valueContainsNull: a Boolean (default True) indicating whether the map's values may be null.
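To make the three parameters concrete, here is a plain-Python check that mirrors the rules a MapType imposes on a dictionary. conforms_to_map is a hypothetical helper written for illustration, not part of PySpark; Spark applies its own checks, for example when createDataFrame() is given an explicit schema.

```python
def conforms_to_map(mapping, key_type, value_type, value_contains_null=True):
    """Hypothetical helper: does a dict satisfy MapType(key_type, value_type, value_contains_null)?

    key_type and value_type are Python classes (e.g. str, int) standing in for
    Spark types such as StringType() and IntegerType().
    """
    for key, value in mapping.items():
        # keyType: keys are never null and all share one type
        if key is None or not isinstance(key, key_type):
            return False
        if value is None:
            # valueContainsNull decides whether null values are allowed
            if not value_contains_null:
                return False
        # valueType: every non-null value shares one type
        elif not isinstance(value, value_type):
            return False
    return True
```

For instance, {"apple": 3, "orange": "two"} fails the check for (str, int) because its values do not share one type, which is exactly the mixed-value case this tutorial is about.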
Create PySpark MapType
In PySpark, you create a MapType with the MapType class from the pyspark.sql.types module, passing the key type and the value type. The example below builds a map type with string keys and integer values and uses it as the type of the fruit_counts column.
Python3
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, MapType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()

# MapType(keyType, valueType): string keys, integer values
key_type = StringType()
value_type = IntegerType()
map_type = MapType(key_type, value_type)

data = [
    (1, {"apple": 3, "orange": 2}),
    (2, {"banana": 4, "kiwi": 1}),
    (3, {"grape": 5, "mango": 2}),
]

# Declare the schema up front instead of patching an inferred schema
schema = StructType([
    StructField("id", IntegerType()),
    StructField("fruit_counts", map_type),
])
df = spark.createDataFrame(data, schema)
df.printSchema()
df.show(truncate=False)
Accessing Map Values and Filtering Rows
To read a value from a map column, use the Column.getItem() method with the key; to keep only the rows that meet a condition, pass that condition to DataFrame.filter(). If a row's map does not contain the requested key, getItem() returns null, and the comparison then drops that row. The example below reads the apple count from each map and keeps the rows where it is greater than 2.
Python3
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

data = [
    (1, {"apple": 3, "orange": 2}),
    (2, {"banana": 4, "kiwi": 1}),
    (3, {"grape": 5, "mango": 2}),
]
df = spark.createDataFrame(data, ["id", "fruit_counts"])

# Pull the "apple" value out of each map; rows without that key get null
df = df.withColumn("apple_count", col("fruit_counts").getItem("apple"))

# Keep rows where apple_count > 2 (a null comparison is not true, so those rows are dropped)
filtered_df = df.filter(col("apple_count") > 2)
filtered_df.show(truncate=False)
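filter() keeps or drops whole rows. If the goal is instead to drop individual entries inside each map (for example, every fruit with a count of 2 or less), a UDF can rebuild the map. The sketch below assumes the map<string, bigint> column that PySpark infers for this sample data; keep_counts_above and make_keep_counts_above_udf are hypothetical names.

```python
def keep_counts_above(counts, threshold):
    """Return a new dict holding only the entries whose value exceeds threshold."""
    if counts is None:  # null map in, null map out
        return None
    return {fruit: n for fruit, n in counts.items() if n is not None and n > threshold}


def make_keep_counts_above_udf(threshold):
    """Wrap keep_counts_above as a UDF with a fixed threshold (hypothetical helper)."""
    from pyspark.sql.functions import udf
    from pyspark.sql.types import LongType, MapType, StringType

    # The threshold is captured in the lambda's closure and shipped with the UDF
    return udf(lambda counts: keep_counts_above(counts, threshold),
               MapType(StringType(), LongType()))
```

Applied as df.withColumn("big_counts", make_keep_counts_above_udf(2)("fruit_counts")), every row is kept but each map retains only its larger counts. On Spark 3.1 and later, the built-in map_filter() function, e.g. map_filter("fruit_counts", lambda k, v: v > 2), does the same without a Python UDF and is usually faster.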
Exploding a MapType column
To explode a MapType column in PySpark, use the explode() function from pyspark.sql.functions. It turns each key-value pair of the map into a separate row, with the key and the value in two columns (named key and value unless you alias them). The example below explodes the fruit_counts column into fruit and count columns.
Python3
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode

spark = SparkSession.builder.getOrCreate()

data = [
    (1, {"apple": 3, "orange": 2}),
    (2, {"banana": 4, "kiwi": 1}),
    (3, {"grape": 5, "mango": 2}),
]
df = spark.createDataFrame(data, ["id", "fruit_counts"])

# One output row per key-value pair; the two generated columns are named fruit and count
exploded_df = df.select("id", explode("fruit_counts").alias("fruit", "count"))
exploded_df.show(truncate=False)
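To see what explode() does to each row, the generator below reproduces its behavior on plain Python (id, dict) pairs. It is an illustrative sketch, not how Spark implements explode(). Note that explode() drops rows whose map is null or empty, while explode_outer() keeps them with a null key and value.

```python
def explode_map_rows(rows):
    """Emulate select("id", explode(map_col)) on (id, dict) pairs."""
    for row_id, mapping in rows:
        # explode() produces no output rows for a null or empty map
        if not mapping:
            continue
        # One output row per key-value pair
        for key, value in mapping.items():
            yield (row_id, key, value)
```

For instance, list(explode_map_rows([(1, {"apple": 3, "orange": 2}), (2, {})])) returns [(1, 'apple', 3), (1, 'orange', 2)].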
UDF of MapType with mixed value types
A Spark MapType has exactly one value type, so a map column cannot literally hold integers in some entries and strings in others; PySpark's schema inference can fail on data like this with a type-merge error. A common workaround is to store the mixed values as strings in a map<string,string> column and let the UDF decide how to treat each value. In the example below, the sample data is first normalized so that every map value is a string, and the DataFrame is created with an explicit schema. The Python function process_map then doubles values that parse as integers, converts every other string to uppercase, and passes nulls through unchanged. udf() wraps the function with the return type MapType(StringType(), StringType()), and withColumn() applies it to the fruit_counts column, producing a new column named processed_counts. Finally, show() displays the original and processed maps side by side.
Python3
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType, MapType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()

raw_data = [
    (1, {"apple": 3, "orange": "two"}),
    (2, {"banana": "four", "kiwi": 1}),
    (3, {"grape": 5, "mango": "three"}),
]

# A MapType column has a single value type, so store every value as a string
data = [
    (row_id, {k: (None if v is None else str(v)) for k, v in counts.items()})
    for row_id, counts in raw_data
]
schema = StructType([
    StructField("id", IntegerType()),
    StructField("fruit_counts", MapType(StringType(), StringType())),
])
df = spark.createDataFrame(data, schema)

def process_map(map_value):
    # Spark passes a null map to the UDF as None
    if map_value is None:
        return None
    processed_map = {}
    for key, value in map_value.items():
        if value is None:
            processed_map[key] = None
            continue
        try:
            # Integer-like strings are doubled and stored back as strings
            processed_map[key] = str(int(value) * 2)
        except ValueError:
            # Any other string is converted to uppercase
            processed_map[key] = value.upper()
    return processed_map

process_map_udf = udf(process_map, MapType(StringType(), StringType()))
processed_df = df.withColumn("processed_counts", process_map_udf(col("fruit_counts")))
processed_df.show(truncate=False)
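When mixed Python values have to be squeezed into a map<string,string> column, booleans need care: bool is a subclass of int in Python, so isinstance(True, int) is True, and a check written as "if it is an int, double it" would turn True into 2. The hypothetical helpers below convert each value to a string (or None), testing for bool before anything else. Rendering booleans as "true"/"false" is a choice made here to match how Spark casts booleans to strings.

```python
def normalize_value(value):
    """Convert one mixed-type Python value to a string, keeping None as None."""
    if value is None:
        return None
    # Test bool first: isinstance(True, int) is True in Python
    if isinstance(value, bool):
        return "true" if value else "false"
    return value if isinstance(value, str) else str(value)


def normalize_map(mapping):
    """Apply normalize_value to every value of a dict; keys are left unchanged."""
    if mapping is None:
        return None
    return {key: normalize_value(value) for key, value in mapping.items()}
```

For example, normalize_map({"apple": 3, "ripe": True, "colour": "red"}) returns {'apple': '3', 'ripe': 'true', 'colour': 'red'}, which can then be loaded with a MapType(StringType(), StringType()) schema.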
With multiple value types using MapType and UDF
The function process_map takes a dictionary, map_data, that plays the role of the map. It reads the 'integer' key with get() and, if the value is an integer (checked with isinstance()), doubles it. It then reads the 'array' key and, if the value is a list, converts each string element to uppercase with a list comprehension. Finally, it reads the 'string' key and, if the value is a string, converts it to lowercase with lower(). The function updates map_data in place and returns it.
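This works because a plain Python dictionary can hold values of different types, so each value can be handled according to its type. A Spark MapType column, however, has a single value type; to carry a record like this in a DataFrame, either store the values in a common type (such as strings) or model the record as a struct, whose fields can each have their own type.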
Python3
def process_map(map_data):
    # Double the 'integer' entry if it holds an int (bool is excluded because it subclasses int)
    integer_value = map_data.get('integer')
    if isinstance(integer_value, int) and not isinstance(integer_value, bool):
        map_data['integer'] = integer_value * 2
    # Uppercase the string elements of the 'array' entry
    array_value = map_data.get('array')
    if isinstance(array_value, list):
        map_data['array'] = [e.upper() if isinstance(e, str) else e for e in array_value]
    # Lowercase the 'string' entry
    string_value = map_data.get('string')
    if isinstance(string_value, str):
        map_data['string'] = string_value.lower()
    # map_data is modified in place and also returned
    return map_data

my_map = {
    'integer': 10,
    'array': ['apple', 'banana', 'cherry'],
    'string': 'Hello World'
}
processed_map = process_map(my_map)
print(processed_map)
Output:
{'integer': 20, 'array': ['APPLE', 'BANANA', 'CHERRY'], 'string': 'hello world'}
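Inside a DataFrame, a record whose keys carry different types (an integer, an array, a string) fits a struct better than a map, because each struct field has its own type. The sketch below is one possible way to express the same processing as a UDF that returns a struct. process_record and make_process_record_udf are hypothetical names, and the input is assumed to arrive as three separate columns named integer, array, and string.

```python
def process_record(integer_value, array_value, string_value):
    """Plain-Python body: double the integer, uppercase the array, lowercase the string."""
    doubled = integer_value * 2 if integer_value is not None else None
    upper = (
        [e.upper() if isinstance(e, str) else e for e in array_value]
        if array_value is not None else None
    )
    lower = string_value.lower() if string_value is not None else None
    # The tuple's order must match the struct fields declared below
    return (doubled, upper, lower)


def make_process_record_udf():
    """Wrap process_record as a UDF returning a struct (hypothetical helper)."""
    from pyspark.sql.functions import udf
    from pyspark.sql.types import ArrayType, LongType, StringType, StructField, StructType

    result_type = StructType([
        StructField("integer", LongType()),
        StructField("array", ArrayType(StringType())),
        StructField("string", StringType()),
    ])
    return udf(process_record, result_type)
```

With such a DataFrame, df.withColumn("processed", make_process_record_udf()("integer", "array", "string")) adds a struct column whose fields can be read back with col("processed.integer") and so on.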
Using JSON data
This example works with JSON-style records held in memory as a list of Python dictionaries; to load an actual JSON file, spark.read.json() is the usual entry point. Because the nested details values mix an integer, an array, strings, and even a nested object, they cannot go into a single MapType column as they are. Each non-string value is therefore encoded as JSON text with json.dumps(), so that details fits a map<string,string> column declared in an explicit schema. A UDF named extract_details reads the age, interests, and zip entries from each map (decoding interests back into a list with json.loads()) and returns them as one formatted string. withColumn() adds the result as a new column, details_extracted; select() keeps name, details_extracted, and date; show() prints them; and spark.stop() ends the session.
Python3
import json

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()

json_data = [
    {
        'name': 'John',
        'details': {
            'age': 25,
            'interests': ['reading', 'music'],
            'zip': '12345'
        },
        'date': '2023-06-30'
    },
    {
        'name': 'Jane',
        'details': {
            'age': 30,
            'interests': ['sports', 'movies'],
            'zip': '67890',
            'extra_info': {'height': 170, 'weight': 60}
        },
        'date': '2023-06-30'
    }
]

# MapType needs one value type, so encode every non-string value as JSON text
def to_string_map(details):
    return {k: v if isinstance(v, str) else json.dumps(v) for k, v in details.items()}

rows = [(r['name'], to_string_map(r['details']), r['date']) for r in json_data]
schema = StructType([
    StructField('name', StringType()),
    StructField('details', MapType(StringType(), StringType())),
    StructField('date', StringType()),
])
df = spark.createDataFrame(rows, schema)

def extract_details(details_map):
    # Spark passes a null map to the UDF as None
    if details_map is None:
        return None
    age = details_map.get('age')
    interests_json = details_map.get('interests')
    interests = json.loads(interests_json) if interests_json else []
    zip_code = details_map.get('zip')
    return f"Age: {age}, Interests: {', '.join(interests)}, Zip Code: {zip_code}"

extract_details_udf = udf(extract_details, StringType())
df_with_details = df.withColumn('details_extracted', extract_details_udf(df['details']))
df_with_details.select('name', 'details_extracted', 'date').show(truncate=False)
spark.stop()
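To start from an actual file instead of in-memory records, note that spark.read.json() expects JSON Lines by default (one complete object per line; pretty-printed files need the multiLine option). The sketch below writes records in that layout and reads them back with an explicit DDL schema so that details is loaded as a map<string,string> column. to_json_lines, write_json_lines, and read_people are hypothetical helpers, and the sketch assumes Spark's JSON reader keeps non-string values (numbers, arrays, nested objects) as their raw JSON text when the target type is STRING, which is worth confirming on your Spark version.

```python
import json


def to_json_lines(records):
    """Serialize records as JSON Lines: one JSON object per line."""
    return "".join(json.dumps(record) + "\n" for record in records)


def write_json_lines(records, path):
    """Write records to path in the layout spark.read.json() reads by default."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(to_json_lines(records))


def read_people(spark, path):
    """Read the file with an explicit schema so 'details' becomes a MapType column."""
    # Assumption: non-string JSON values are kept as raw JSON text in the STRING map values
    schema = "name STRING, details MAP<STRING, STRING>, date STRING"
    return spark.read.schema(schema).json(path)
```

For example, write_json_lines(json_data, "people.json") followed by read_people(spark, "people.json") should give a details column in which interests holds the array as JSON text, ready to be decoded with json.loads() inside a UDF.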