PySpark - Flatten Nested Struct Column
Key Insights
• Flattening nested struct columns transforms hierarchical data into a flat schema, making it easier to query and compatible with systems that don’t support complex types like traditional SQL databases or CSV exports.
• PySpark provides dot notation for manual flattening of known structures, but recursive functions are essential for dynamically handling deeply nested or unknown schemas at scale.
• Arrays of structs require a two-step approach: first exploding the array into multiple rows, then flattening the resulting struct columns using standard techniques.
Understanding the Problem
Nested struct columns are ubiquitous in modern data engineering. When you ingest JSON from REST APIs, process event streams, or work with semi-structured data from NoSQL databases, you’ll encounter hierarchical data structures. While PySpark handles these complex types natively, many downstream systems don’t. Traditional relational databases, business intelligence tools, and simple file formats like CSV require flat schemas.
Consider this common scenario: you’re ingesting user profile data from an API that returns nested JSON. Here’s what that looks like in PySpark:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
spark = SparkSession.builder.appName("FlattenStructs").getOrCreate()
# Sample data with nested structures
data = [
    (1, ("John", "Doe", ("123 Main St", "Springfield", "IL", "62701"))),
    (2, ("Jane", "Smith", ("456 Oak Ave", "Portland", "OR", "97201"))),
]

schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("person", StructType([
        StructField("first_name", StringType(), True),
        StructField("last_name", StringType(), True),
        StructField("address", StructType([
            StructField("street", StringType(), True),
            StructField("city", StringType(), True),
            StructField("state", StringType(), True),
            StructField("zip", StringType(), True)
        ]), True)
    ]), True)
])
df = spark.createDataFrame(data, schema)
df.show(truncate=False)
This creates a DataFrame with three levels of nesting. The person column contains a struct, which itself contains an address struct. Let’s explore how to flatten this effectively.
Inspecting Nested Schemas
Before flattening, you need to understand your data structure. PySpark provides several methods for schema inspection:
# View the schema in tree format
df.printSchema()
# Output:
# root
# |-- id: integer (nullable = true)
# |-- person: struct (nullable = true)
# | |-- first_name: string (nullable = true)
# | |-- last_name: string (nullable = true)
# | |-- address: struct (nullable = true)
# | | |-- street: string (nullable = true)
# | | |-- city: string (nullable = true)
# | | |-- state: string (nullable = true)
# | | |-- zip: string (nullable = true)
# Programmatically inspect column types
from pyspark.sql.types import StructType
for field in df.schema.fields:
    print(f"{field.name}: {field.dataType}")
    if isinstance(field.dataType, StructType):
        for subfield in field.dataType.fields:
            print(f"  {subfield.name}: {subfield.dataType}")
Understanding the schema structure is critical for both manual and automated flattening approaches.
Manual Flattening with Dot Notation
For simple, known structures, dot notation provides the most straightforward approach. You can access nested fields by chaining field names with dots:
from pyspark.sql.functions import col
# Flatten using dot notation
flattened_df = df.select(
    col("id"),
    col("person.first_name").alias("first_name"),
    col("person.last_name").alias("last_name"),
    col("person.address.street").alias("street"),
    col("person.address.city").alias("city"),
    col("person.address.state").alias("state"),
    col("person.address.zip").alias("zip")
)
flattened_df.show()
# +---+----------+---------+-----------+-----------+-----+-----+
# | id|first_name|last_name| street| city|state| zip|
# +---+----------+---------+-----------+-----------+-----+-----+
# | 1| John| Doe|123 Main St|Springfield| IL|62701|
# | 2| Jane| Smith|456 Oak Ave| Portland| OR|97201|
# +---+----------+---------+-----------+-----------+-----+-----+
This approach works well when you know the exact schema and it doesn’t change frequently. The alias() method lets you control the output column names, avoiding unwieldy names like person.address.street.
Recursive Flattening for Dynamic Schemas
When dealing with unknown or frequently changing schemas, manual flattening becomes impractical. A recursive function can traverse any struct hierarchy and generate the appropriate select statements:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.types import StructType

def flatten_struct(df: DataFrame, separator: str = "_") -> DataFrame:
    """
    Recursively flatten all struct columns in a DataFrame.

    Args:
        df: Input DataFrame with nested structs
        separator: Character to use when joining nested field names

    Returns:
        DataFrame with all structs flattened
    """
    # Split columns into flat columns and struct columns
    flat_cols = []
    nested_cols = []
    for field in df.schema.fields:
        if isinstance(field.dataType, StructType):
            nested_cols.append(field.name)
        else:
            flat_cols.append(col(field.name))

    # If no nested columns, return as-is
    if not nested_cols:
        return df

    # Expand each struct column one level
    for nested_col in nested_cols:
        struct_fields = df.schema[nested_col].dataType.fields
        for struct_field in struct_fields:
            flat_cols.append(
                col(f"{nested_col}.{struct_field.name}").alias(
                    f"{nested_col}{separator}{struct_field.name}"
                )
            )

    # Select flattened columns
    df_flat = df.select(flat_cols)

    # Recursively flatten if there are still nested structs
    return flatten_struct(df_flat, separator)
# Apply the function
result = flatten_struct(df)
result.printSchema()
result.show()
# root
# |-- id: integer (nullable = true)
# |-- person_first_name: string (nullable = true)
# |-- person_last_name: string (nullable = true)
# |-- person_address_street: string (nullable = true)
# |-- person_address_city: string (nullable = true)
# |-- person_address_state: string (nullable = true)
# |-- person_address_zip: string (nullable = true)
This function handles arbitrary nesting depth and automatically generates meaningful column names by concatenating the field hierarchy with underscores (or your chosen separator).
Handling Arrays of Structs
Arrays of structs present a unique challenge. You can’t simply flatten them—you need to decide whether to explode the array into multiple rows or handle it differently:
from pyspark.sql.functions import explode
# Sample data with array of structs. Tuples plus an explicit schema are used
# here because lists of Python dicts would be inferred as map columns, not
# arrays of structs.
array_data = [
    (1, "John", [("Laptop", 1200), ("Mouse", 25)]),
    (2, "Jane", [("Monitor", 300)]),
]
array_schema = "id INT, name STRING, purchases ARRAY<STRUCT<product: STRING, price: INT>>"
df_array = spark.createDataFrame(array_data, array_schema)
df_array.printSchema()
# Step 1: Explode the array
df_exploded = df_array.select(
    col("id"),
    col("name"),
    explode(col("purchases")).alias("purchase")
)
# Step 2: Flatten the struct
df_final = df_exploded.select(
    col("id"),
    col("name"),
    col("purchase.product").alias("product"),
    col("purchase.price").alias("price")
)
df_final.show()
# +---+----+-------+-----+
# | id|name|product|price|
# +---+----+-------+-----+
# | 1|John| Laptop| 1200|
# | 1|John| Mouse| 25|
# | 2|Jane|Monitor| 300|
# +---+----+-------+-----+
The explode() function creates a new row for each array element, and it silently drops rows whose array is null or empty; use explode_outer() if those rows must be kept. If you have multiple array columns and explode all of them, you’ll get a Cartesian product, so use this carefully. For cases where you want to keep arrays intact but flatten their struct contents, you’ll need to use higher-order functions like transform().
Performance Considerations and Best Practices
Flattening has performance implications you should understand:
Column Pruning: Flat schemas enable better column pruning with formats like Parquet. If you only need a few fields, Spark can read just those columns instead of entire nested structures.
Wide Tables: Extremely nested data can produce hundreds of columns when flattened. This can slow down the query optimizer and make DataFrames unwieldy. Consider flattening selectively—only the parts you need.
Naming Conventions: Establish consistent naming patterns. Using underscores as separators (parent_child_field) is standard and avoids issues with systems that don’t support dots in column names.
# Bad: Inconsistent naming
df.select(
    col("person.first_name").alias("firstName"),
    col("person.address.city").alias("person_city")
)

# Good: Consistent snake_case with clear hierarchy
df.select(
    col("person.first_name").alias("person_first_name"),
    col("person.address.city").alias("person_address_city")
)
Schema Evolution: If your source schema changes frequently, automated flattening functions are essential. They adapt to new fields automatically, though you should monitor for breaking changes in your pipeline.
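One lightweight way to monitor for those breaking changes is to compare each batch's flattened column names against a stored baseline. A minimal sketch, assuming you persist the baseline list yourself; the helper name schema_drift is my own:

```python
def schema_drift(baseline_cols, current_cols):
    """Report columns added to or removed from the flattened schema."""
    baseline, current = set(baseline_cols), set(current_cols)
    return {
        "added": sorted(current - baseline),
        "removed": sorted(baseline - current),
    }

drift = schema_drift(
    ["id", "person_first_name", "person_last_name"],
    ["id", "person_first_name", "person_email"],
)
print(drift)  # {'added': ['person_email'], 'removed': ['person_last_name']}
```

Additions are usually benign with automated flattening; removals are the ones worth alerting on, since downstream consumers may depend on those columns.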
When Not to Flatten: Keep nested structures when you’re working entirely within Spark, especially for intermediate transformations. Nested data is more compact and can be more efficient for certain operations. Only flatten when interfacing with systems that require it.
Conclusion
Flattening nested struct columns in PySpark is a fundamental skill for data engineers working with semi-structured data. Use dot notation with select() for simple, known schemas where you need full control over output column names. Deploy recursive flattening functions for dynamic schemas or deeply nested structures that would be tedious to flatten manually.
For arrays of structs, combine explode() with flattening techniques, but be mindful of the data expansion this creates. Always consider whether flattening is necessary—sometimes keeping the nested structure is more efficient.
The recursive function presented here serves as a production-ready starting point that you can customize with different naming conventions, selective flattening logic, or special handling for specific data types. Master these techniques and you’ll handle any nested data structure PySpark throws at you.