PySpark - Stack Function to Unpivot

Key Insights

  • The stack() function transforms wide-format DataFrames into long-format by converting multiple columns into row-based key-value pairs, essential for normalizing denormalized datasets
  • stack() requires an exact column count and works through expr() or selectExpr(), making it more explicit but also more error-prone than pandas’ melt
  • For production pipelines, combine stack() with column preservation strategies and null handling to avoid data loss during unpivoting operations

Introduction to Unpivoting in PySpark

Unpivoting transforms column-oriented data into row-oriented data. If you’ve worked with denormalized datasets—think spreadsheets with months as column headers or survey data with question columns—you’ve encountered the need to unpivot.

The problem is simple: you have data spread across multiple columns that should be represented as rows. Instead of Q1, Q2, Q3, Q4 as separate columns, you want a quarter column and a sales column with four times as many rows.

PySpark’s stack() function solves this by reshaping your DataFrame from wide to long format. Here’s what the transformation looks like:

from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

spark = SparkSession.builder.appName("unpivot").getOrCreate()

# Wide format - difficult to analyze
wide_df = spark.createDataFrame([
    ("Product_A", 100, 150, 200, 180),
    ("Product_B", 120, 130, 140, 160)
], ["product", "Q1", "Q2", "Q3", "Q4"])

wide_df.show()
# +---------+---+---+---+---+
# |  product| Q1| Q2| Q3| Q4|
# +---------+---+---+---+---+
# |Product_A|100|150|200|180|
# |Product_B|120|130|140|160|
# +---------+---+---+---+---+

# Long format - analysis-ready
long_df = wide_df.selectExpr(
    "product",
    "stack(4, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, sales)"
)

long_df.show()
# +---------+-------+-----+
# |  product|quarter|sales|
# +---------+-------+-----+
# |Product_A|     Q1|  100|
# |Product_A|     Q2|  150|
# |Product_A|     Q3|  200|
# |Product_A|     Q4|  180|
# |Product_B|     Q1|  120|
# |Product_B|     Q2|  130|
# |Product_B|     Q3|  140|
# |Product_B|     Q4|  160|
# +---------+-------+-----+

The long format enables grouping, filtering, and time-series analysis that would be painful with the wide format.

Understanding the Stack Function Syntax

The stack() function signature is straightforward but requires precision:

stack(n, col1_name, col1_value, col2_name, col2_value, ...)
  • n: The number of rows to create per input row (must match the number of column pairs)
  • col_name, col_value pairs: Alternating column name literals and column references

You must use stack() within expr() or selectExpr() because it’s a SQL function, not a DataFrame method. The function creates new columns based on the alias you provide.

Here’s a minimal example:

# Simple two-column unpivot
simple_df = spark.createDataFrame([
    (1, "Alice", 85, 92),
    (2, "Bob", 78, 88)
], ["id", "name", "math_score", "english_score"])

unpivoted = simple_df.selectExpr(
    "id",
    "name", 
    "stack(2, 'math', math_score, 'english', english_score) as (subject, score)"
)

unpivoted.show()
# +---+-----+-------+-----+
# | id| name|subject|score|
# +---+-----+-------+-----+
# |  1|Alice|   math|   85|
# |  1|Alice|english|   92|
# |  2|  Bob|   math|   78|
# |  2|  Bob|english|   88|
# +---+-----+-------+-----+

The number 2 tells stack() to create 2 rows per input row. The pairs 'math', math_score and 'english', english_score define what goes into the new subject and score columns.

Practical Use Case: Sales Data Transformation

Let’s work through a realistic scenario: transforming quarterly sales data from a financial reporting system.

# Typical wide-format sales data from a database or CSV
sales_wide = spark.createDataFrame([
    ("North", "Electronics", 45000, 52000, 48000, 61000),
    ("North", "Clothing", 23000, 28000, 31000, 42000),
    ("South", "Electronics", 38000, 41000, 39000, 47000),
    ("South", "Clothing", 19000, 22000, 26000, 35000)
], ["region", "category", "Q1_sales", "Q2_sales", "Q3_sales", "Q4_sales"])

print("Wide format:")
sales_wide.show()

# Transform to long format for time-series analysis
sales_long = sales_wide.selectExpr(
    "region",
    "category",
    """stack(4, 
        'Q1', Q1_sales,
        'Q2', Q2_sales,
        'Q3', Q3_sales,
        'Q4', Q4_sales
    ) as (quarter, sales_amount)"""
)

print("Long format:")
sales_long.show()

# Now you can easily analyze trends
from pyspark.sql.functions import sum, avg

quarterly_totals = sales_long.groupBy("quarter").agg(
    sum("sales_amount").alias("total_sales"),
    avg("sales_amount").alias("avg_sales")
).orderBy("quarter")

quarterly_totals.show()
# +-------+-----------+------------------+
# |quarter|total_sales|         avg_sales|
# +-------+-----------+------------------+
# |     Q1|     125000|           31250.0|
# |     Q2|     143000|           35750.0|
# |     Q3|     144000|           36000.0|
# |     Q4|     185000|           46250.0|
# +-------+-----------+------------------+

This transformation enables standard analytical queries. You can filter by quarter, calculate growth rates, or join with date dimensions—operations that are awkward with wide-format data.

Advanced Stack Patterns

Real-world scenarios require more sophisticated patterns. Here’s how to handle multiple unpivot operations while preserving identifier columns:

# Complex dataset with multiple metrics to unpivot
metrics_df = spark.createDataFrame([
    ("2024-01", "Store_A", 1000, 850, 45, 38),
    ("2024-01", "Store_B", 1200, 1050, 52, 47),
    ("2024-02", "Store_A", 1100, 920, 48, 41)
], ["month", "store_id", "revenue_online", "revenue_retail", "customers_online", "customers_retail"])

# Unpivot both revenue and customer counts
unpivoted_metrics = metrics_df.selectExpr(
    "month",
    "store_id",
    """stack(2,
        'online', revenue_online, customers_online,
        'retail', revenue_retail, customers_retail
    ) as (channel, revenue, customers)"""
)

unpivoted_metrics.show()
# +-------+--------+-------+-------+---------+
# |  month|store_id|channel|revenue|customers|
# +-------+--------+-------+-------+---------+
# |2024-01| Store_A| online|   1000|       45|
# |2024-01| Store_A| retail|    850|       38|
# |2024-01| Store_B| online|   1200|       52|
# |2024-01| Store_B| retail|   1050|       47|
# |2024-02| Store_A| online|   1100|       48|
# |2024-02| Store_A| retail|    920|       41|
# +-------+--------+-------+-------+---------+

Notice how stack() can handle multiple value columns simultaneously. Each row in the stack specification can contain multiple values.

For handling nulls explicitly:

# Data with missing values
sparse_df = spark.createDataFrame([
    ("Item_1", 100, None, 300),
    ("Item_2", None, 200, None)
], ["item", "jan", "feb", "mar"])

# Stack preserves nulls
stacked = sparse_df.selectExpr(
    "item",
    "stack(3, 'jan', jan, 'feb', feb, 'mar', mar) as (month, value)"
)

stacked.show()
# +------+-----+-----+
# |  item|month|value|
# +------+-----+-----+
# |Item_1|  jan|  100|
# |Item_1|  feb| null|
# |Item_1|  mar|  300|
# |Item_2|  jan| null|
# |Item_2|  feb|  200|
# |Item_2|  mar| null|
# +------+-----+-----+

# Filter out nulls if needed
stacked.filter("value is not null").show()

Stack vs. Alternative Approaches

You can achieve unpivoting without stack(), but it’s verbose and slower:

# Union-based approach (avoid this)
from pyspark.sql.functions import col, lit

q1 = wide_df.select("product", lit("Q1").alias("quarter"), col("Q1").alias("sales"))
q2 = wide_df.select("product", lit("Q2").alias("quarter"), col("Q2").alias("sales"))
q3 = wide_df.select("product", lit("Q3").alias("quarter"), col("Q3").alias("sales"))
q4 = wide_df.select("product", lit("Q4").alias("quarter"), col("Q4").alias("sales"))

union_result = q1.union(q2).union(q3).union(q4)

# This works but requires multiple scans of the data
# Stack is more efficient - single pass

The union approach reads the source data multiple times and generates a more complex execution plan. For datasets with many columns to unpivot, stack() is significantly more efficient.

In SQL, you can use stack() directly:

# SQL approach
wide_df.createOrReplaceTempView("sales")

sql_result = spark.sql("""
    SELECT 
        product,
        quarter,
        sales
    FROM sales
    LATERAL VIEW stack(4,
        'Q1', Q1,
        'Q2', Q2,
        'Q3', Q3,
        'Q4', Q4
    ) AS quarter, sales
""")

Both approaches work identically. Choose based on your team’s preference for DataFrame API vs. SQL.

Common Pitfalls and Best Practices

The most common error is mismatching the count parameter:

# WRONG - says 3 but provides 4 pairs
wrong_df = wide_df.selectExpr(
    "product",
    "stack(3, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3, 'Q4', Q4) as (quarter, sales)"
)
# Fails with an AnalysisException: with n=3, the eight expressions split into
# three output columns per row, so the two-name alias (quarter, sales) no
# longer matches the generator's output

The formula is: total_args = 1 + (n * columns_per_row). If you’re creating 2 columns (quarter, sales), and want 4 rows, you need 1 + (4 * 2) = 9 arguments total.

Mismatched column types are another trap, because stack() requires all stacked value columns to share a compatible type:

# Mixed types require casting
mixed_df = spark.createDataFrame([
    ("A", 100, "150", 200)  # Q2 is string
], ["product", "Q1", "Q2", "Q3"])

# Stacking int and string columns together raises a data type mismatch error
# Cast everything to the same type first
from pyspark.sql.functions import col

normalized = mixed_df.select(
    "product",
    col("Q1").cast("int"),
    col("Q2").cast("int"),
    col("Q3").cast("int")
)

# Now stack safely
result = normalized.selectExpr(
    "product",
    "stack(3, 'Q1', Q1, 'Q2', Q2, 'Q3', Q3) as (quarter, sales)"
)

Best practices:

  1. Validate column counts before stacking in production pipelines
  2. Cast columns to consistent types before unpivoting
  3. Preserve key columns by selecting them alongside the stack expression
  4. Consider partitioning the result by the new categorical column for downstream processing
  5. Don’t use stack() when you need to unpivot hundreds of columns—consider restructuring your data model instead

The stack() function is the right tool for transforming wide analytical datasets into normalized, query-friendly formats. Master it, and you’ll handle data reshaping tasks that would otherwise require complex custom logic or external tools.