PySpark - Stack Function to Unpivot
Key Insights
- The stack() function transforms wide-format DataFrames into long-format by converting multiple columns into row-based key-value pairs, essential for normalizing denormalized datasets
- Stack requires exact column count specification and works through expr() or selectExpr(), making it more explicit but also more error-prone than pandas’ melt
- For production pipelines, combine stack() with column preservation strategies and null handling to avoid data loss during unpivoting operations
Introduction to Unpivoting in PySpark
Unpivoting transforms column-oriented data into row-oriented data. If you’ve worked with denormalized datasets—think spreadsheets with months as column headers or survey data with question columns—you’ve encountered the need to unpivot.
The problem is simple: you have data spread across multiple columns that should be represented as rows. Instead of Q1, Q2, Q3, Q4 as separate columns, you want a quarter column and a sales column with four times as many rows.
PySpark’s stack() function solves this by reshaping your DataFrame from wide to long format. Here’s what the transformation looks like:
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr
spark = SparkSession.builder.appName("unpivot").getOrCreate()
# Wide format - difficult to analyze
wide_df = spark.createDataFrame([
("Product_A", 100, 150, 200, 180),
("Product_B", 120, 130, 140, 160)
], ["product", "Q1", "Q2", "Q3", "Q4"])
wide_df.show()
# +---------+---+---+---+---+
# | product| Q1| Q2| Q3| Q4|
# +---------+---+---+---+---+
# |Product_A|100|150|200|180|
# |Product_B|120|130|140|160|
# +---------+---+---+---+---+
# Long format - analysis-ready
long_df = wide_df.selectExpr(
"product",
"stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, sales)"
)
long_df.show()
# +---------+-------+-----+
# | product|quarter|sales|
# +---------+-------+-----+
# |Product_A| Q1| 100|
# |Product_A| Q2| 150|
# |Product_A| Q3| 200|
# |Product_A| Q4| 180|
# |Product_B| Q1| 120|
# |Product_B| Q2| 130|
# |Product_B| Q3| 140|
# |Product_B| Q4| 160|
# +---------+-------+-----+
The long format enables grouping, filtering, and time-series analysis that would be painful with the wide format.
Understanding the Stack Function Syntax
The stack() function signature is straightforward but requires precision:
stack(n, col1_name, col1_value, col2_name, col2_value, ...)
- n: The number of rows to create per input row (must match the number of column pairs)
- col_name, col_value pairs: Alternating column name literals and column references
You must use stack() within expr() or selectExpr() because it’s a SQL function, not a DataFrame method. The function creates new columns based on the alias you provide.
Here’s a minimal example:
# Simple two-column unpivot
simple_df = spark.createDataFrame([
(1, "Alice", 85, 92),
(2, "Bob", 78, 88)
], ["id", "name", "math_score", "english_score"])
unpivoted = simple_df.selectExpr(
"id",
"name",
"stack(2, 'math', math_score, 'english', english_score) as (subject, score)"
)
unpivoted.show()
# +---+-----+-------+-----+
# | id| name|subject|score|
# +---+-----+-------+-----+
# | 1|Alice| math| 85|
# | 1|Alice|english| 92|
# | 2| Bob| math| 78|
# | 2| Bob|english| 88|
# +---+-----+-------+-----+
The number 2 tells stack() to create 2 rows per input row. The pairs 'math', math_score and 'english', english_score define what goes into the new subject and score columns.
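Writing each name/value pair by hand gets tedious and error-prone as column counts grow. One convenient pattern (a sketch of my own, not from the original article) is to generate the stack expression string from a list of column names:

```python
def build_stack_expr(cols, out_names=("key", "value")):
    """Build a stack() expression string for selectExpr().

    Each column contributes a ('label', column) pair, with the label
    defaulting to the column name itself.
    """
    pairs = ", ".join(f"'{c}', {c}" for c in cols)
    return f"stack({len(cols)}, {pairs}) as ({', '.join(out_names)})"

expr_str = build_stack_expr(["Q1", "Q2", "Q3", "Q4"], ("quarter", "sales"))
print(expr_str)
# stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, sales)
```

You would then call something like wide_df.selectExpr("product", expr_str). Note this simple version assumes column names that need no backtick quoting.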
Practical Use Case: Sales Data Transformation
Let’s work through a realistic scenario: transforming quarterly sales data from a financial reporting system.
# Typical wide-format sales data from a database or CSV
sales_wide = spark.createDataFrame([
("North", "Electronics", 45000, 52000, 48000, 61000),
("North", "Clothing", 23000, 28000, 31000, 42000),
("South", "Electronics", 38000, 41000, 39000, 47000),
("South", "Clothing", 19000, 22000, 26000, 35000)
], ["region", "category", "Q1_sales", "Q2_sales", "Q3_sales", "Q4_sales"])
print("Wide format:")
sales_wide.show()
# Transform to long format for time-series analysis
sales_long = sales_wide.selectExpr(
"region",
"category",
"""stack(4,
'Q1', Q1_sales,
'Q2', Q2_sales,
'Q3', Q3_sales,
'Q4', Q4_sales
) as (quarter, sales_amount)"""
)
print("Long format:")
sales_long.show()
# Now you can easily analyze trends
from pyspark.sql.functions import sum, avg
quarterly_totals = sales_long.groupBy("quarter").agg(
sum("sales_amount").alias("total_sales"),
avg("sales_amount").alias("avg_sales")
).orderBy("quarter")
quarterly_totals.show()
# +-------+-----------+------------------+
# |quarter|total_sales| avg_sales|
# +-------+-----------+------------------+
# | Q1| 125000| 31250.0|
# | Q2| 143000| 35750.0|
# | Q3| 144000| 36000.0|
# | Q4| 185000| 46250.0|
# +-------+-----------+------------------+
This transformation enables standard analytical queries. You can filter by quarter, calculate growth rates, or join with date dimensions—operations that are awkward with wide-format data.
Advanced Stack Patterns
Real-world scenarios require more sophisticated patterns. Here’s how to handle multiple unpivot operations while preserving identifier columns:
# Complex dataset with multiple metrics to unpivot
metrics_df = spark.createDataFrame([
("2024-01", "Store_A", 1000, 850, 45, 38),
("2024-01", "Store_B", 1200, 1050, 52, 47),
("2024-02", "Store_A", 1100, 920, 48, 41)
], ["month", "store_id", "revenue_online", "revenue_retail", "customers_online", "customers_retail"])
# Unpivot both revenue and customer counts
unpivoted_metrics = metrics_df.selectExpr(
"month",
"store_id",
"""stack(2,
'online', revenue_online, customers_online,
'retail', revenue_retail, customers_retail
) as (channel, revenue, customers)"""
)
unpivoted_metrics.show()
# +-------+--------+-------+-------+---------+
# | month|store_id|channel|revenue|customers|
# +-------+--------+-------+-------+---------+
# |2024-01| Store_A| online| 1000| 45|
# |2024-01| Store_A| retail| 850| 38|
# |2024-01| Store_B| online| 1200| 52|
# |2024-01| Store_B| retail| 1050| 47|
# |2024-02| Store_A| online| 1100| 48|
# |2024-02| Store_A| retail| 920| 41|
# +-------+--------+-------+-------+---------+
Notice how stack() can handle multiple value columns simultaneously. Each row in the stack specification can contain multiple values.
For handling nulls explicitly:
# Data with missing values
sparse_df = spark.createDataFrame([
("Item_1", 100, None, 300),
("Item_2", None, 200, None)
], ["item", "jan", "feb", "mar"])
# Stack preserves nulls
stacked = sparse_df.selectExpr(
"item",
"stack(3, 'jan', jan, 'feb', feb, 'mar', mar) as (month, value)"
)
stacked.show()
# +------+-----+-----+
# | item|month|value|
# +------+-----+-----+
# |Item_1| jan| 100|
# |Item_1| feb| null|
# |Item_1| mar| 300|
# |Item_2| jan| null|
# |Item_2| feb| 200|
# |Item_2| mar| null|
# +------+-----+-----+
# Filter out nulls if needed
stacked.filter("value is not null").show()
Stack vs. Alternative Approaches
You can achieve unpivoting without stack(), but it’s verbose and slower:
# Union-based approach (avoid this)
from pyspark.sql.functions import col, lit
q1 = wide_df.select("product", lit("Q1").alias("quarter"), col("Q1").alias("sales"))
q2 = wide_df.select("product", lit("Q2").alias("quarter"), col("Q2").alias("sales"))
q3 = wide_df.select("product", lit("Q3").alias("quarter"), col("Q3").alias("sales"))
q4 = wide_df.select("product", lit("Q4").alias("quarter"), col("Q4").alias("sales"))
union_result = q1.union(q2).union(q3).union(q4)
# This works but requires multiple scans of the data
# Stack is more efficient - single pass
The union approach reads the source data multiple times and generates a more complex execution plan. For datasets with many columns to unpivot, stack() is significantly more efficient.
In SQL, you can use stack() directly:
# SQL approach
wide_df.createOrReplaceTempView("sales")
sql_result = spark.sql("""
SELECT
product,
quarter,
sales
FROM sales
LATERAL VIEW stack(4,
'Q1', Q1,
'Q2', Q2,
'Q3', Q3,
'Q4', Q4
) AS quarter, sales
""")
Both approaches work identically. Choose based on your team’s preference for DataFrame API vs. SQL.
Common Pitfalls and Best Practices
The most common error is mismatching the count parameter:
# WRONG - says 3 but provides 4 pairs
wrong_df = wide_df.selectExpr(
"product",
"stack(3, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, sales)"
)
# Fails with an AnalysisException: with n=3, Spark spreads the 8 value
# arguments across 3 rows of 3 columns (padding with nulls), which no
# longer matches the two-column alias (quarter, sales)
The formula is: total_args = 1 + (n * columns_per_row). If you’re creating 2 columns (quarter, sales), and want 4 rows, you need 1 + (4 * 2) = 9 arguments total.
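In a pipeline that builds stack expressions dynamically, this arithmetic is worth checking before submitting the job. A hypothetical guard (the function name and shape are my own, not from the article):

```python
def check_stack_args(n_rows, n_value_args, n_output_cols):
    """Raise early if a stack() call is inconsistent: n_rows times
    n_output_cols must equal the number of value arguments supplied."""
    expected = n_rows * n_output_cols
    if n_value_args != expected:
        raise ValueError(
            f"stack({n_rows}, ...) with {n_output_cols} output columns "
            f"needs {expected} value arguments, got {n_value_args}"
        )

# 4 quarters as ('label', column) pairs = 8 value args, 2 output columns
check_stack_args(4, 8, 2)        # OK
# check_stack_args(3, 8, 2)      # would raise ValueError
```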
Type mismatches are another common failure, since stack() expects the values feeding each output column to share a single type:
# Mixed types require casting
mixed_df = spark.createDataFrame([
("A", 100, "150", 200) # Q2 is string
], ["product", "Q1", "Q2", "Q3"])
# Stacking an int column with a string column fails analysis
# Cast everything to the same type first
from pyspark.sql.functions import col
normalized = mixed_df.select(
"product",
col("Q1").cast("int"),
col("Q2").cast("int"),
col("Q3").cast("int")
)
# Now stack safely
result = normalized.selectExpr(
"product",
"stack(3, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3) as (quarter, sales)"
)
Best practices:
- Validate column counts before stacking in production pipelines
- Cast columns to consistent types before unpivoting
- Preserve key columns by selecting them alongside the stack expression
- Consider partitioning the result by the new categorical column for downstream processing
- Don’t use stack() when you need to unpivot hundreds of columns—consider restructuring your data model instead
The stack() function is the right tool for transforming wide analytical datasets into normalized, query-friendly formats. Master it, and you’ll handle data reshaping tasks that would otherwise require complex custom logic or external tools.