PySpark - Melt DataFrame Example
Key Insights
• PySpark had no native melt() function before version 3.4 (which added DataFrame.unpivot() and its melt() alias), but the stack() function provides equivalent functionality on any version for converting wide-format DataFrames to long format at distributed scale
• The stack() function requires specifying the number of column pairs upfront and uses selectExpr() with explicit column references, making it less dynamic than Pandas but more efficient for distributed processing
• For production workloads, always validate your melted DataFrame schema and consider partitioning strategies since melting operations typically increase row count dramatically while reducing column count
Introduction to Melting DataFrames
Melting a DataFrame transforms data from wide format to long format by unpivoting columns into rows. Instead of having multiple columns representing different variables or time periods, you consolidate them into two columns: one identifying the variable name and another containing its value.
This operation is essential when preparing data for visualization libraries, statistical analysis, or machine learning pipelines that expect tidy data formats. You’ll frequently use melting when dealing with time-series data stored with dates as column names, survey responses with multiple choice columns, or any dataset where observations are spread across columns rather than rows.
The challenge with PySpark is that, before version 3.4 introduced DataFrame.unpivot() and its melt() alias, there was no built-in melt() function. PySpark’s stack() function provides the same capability on any version, with syntax optimized for distributed computing.
Understanding Wide vs Long Format
Wide format stores each entity’s measurements across multiple columns. Long format stores each measurement as a separate row with identifier columns indicating what the measurement represents.
Consider monthly sales data in wide format:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("MeltExample").getOrCreate()
# Wide format: each month is a separate column
wide_data = [
("Product_A", 100, 150, 200),
("Product_B", 80, 120, 160),
("Product_C", 200, 180, 220)
]
wide_df = spark.createDataFrame(wide_data, ["product", "jan_sales", "feb_sales", "mar_sales"])
wide_df.show()
Output:
+---------+---------+---------+---------+
| product|jan_sales|feb_sales|mar_sales|
+---------+---------+---------+---------+
|Product_A| 100| 150| 200|
|Product_B| 80| 120| 160|
|Product_C| 200| 180| 220|
+---------+---------+---------+---------+
In long format, this same data would have rows for each product-month combination, making it easier to filter by month, aggregate across time periods, or feed into analytics tools expecting normalized data structures.
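Stripped of Spark, the reshape itself is a simple expansion: each wide row becomes one long row per value column. A plain-Python sketch of the same transformation (the row data and column names mirror the example above; this is conceptual, not how Spark executes it):

```python
# Plain-Python sketch of melting: each wide row expands into one
# long row per value column.
wide_rows = [
    ("Product_A", 100, 150, 200),
    ("Product_B", 80, 120, 160),
]
months = ["jan_sales", "feb_sales", "mar_sales"]

long_rows = [
    (product, month, value)
    for (product, *values) in wide_rows
    for month, value in zip(months, values)
]
# Each product now contributes one row per month instead of one wide row.
```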
PySpark Melt Implementation Methods
Since older PySpark versions don’t provide a native melt() function, we need to leverage existing functions creatively. The primary approach uses the stack() function, which is designed for exactly this type of transformation but uses a different syntax than Pandas users might expect.
Let’s create a more comprehensive sample dataset to work with:
sample_data = [
(1, "Store_A", 1000, 1500, 1200, 1800),
(2, "Store_B", 800, 900, 1100, 1300),
(3, "Store_C", 1200, 1400, 1600, 2000)
]
df = spark.createDataFrame(
sample_data,
["store_id", "store_name", "Q1_revenue", "Q2_revenue", "Q3_revenue", "Q4_revenue"]
)
df.show()
This creates a wide-format DataFrame where quarterly revenues are stored as separate columns.
Method 1: Using Stack Function
The stack() function is PySpark’s answer to melting. It takes a number indicating how many column pairs to stack, followed by pairs of literal values and column references.
# Melt the quarterly revenue columns
melted_df = df.selectExpr(
"store_id",
"store_name",
"stack(4, 'Q1', Q1_revenue, 'Q2', Q2_revenue, 'Q3', Q3_revenue, 'Q4', Q4_revenue) as (quarter, revenue)"
)
melted_df.show()
Output:
+--------+----------+-------+-------+
|store_id|store_name|quarter|revenue|
+--------+----------+-------+-------+
| 1| Store_A| Q1| 1000|
| 1| Store_A| Q2| 1500|
| 1| Store_A| Q3| 1200|
| 1| Store_A| Q4| 1800|
| 2| Store_B| Q1| 800|
| 2| Store_B| Q2| 900|
| 2| Store_B| Q3| 1100|
| 2| Store_B| Q4| 1300|
| 3| Store_C| Q1| 1200|
| 3| Store_C| Q2| 1400|
| 3| Store_C| Q3| 1600|
| 3| Store_C| Q4| 2000|
+--------+----------+-------+-------+
The stack() function’s first argument (4) specifies how many column pairs we’re unpivoting. Each subsequent pair consists of a literal string (the variable name) and a column reference (the value). The as (quarter, revenue) clause names the two resulting columns.
For a more dynamic approach that doesn’t require hardcoding column names:
# Get columns to melt
id_vars = ["store_id", "store_name"]
value_vars = ["Q1_revenue", "Q2_revenue", "Q3_revenue", "Q4_revenue"]
# Build stack expression dynamically
stack_expr = f"stack({len(value_vars)}, "
# Use the quarter prefix ('Q1') as the label so the output matches the hardcoded version
stack_expr += ", ".join([f"'{col.split('_')[0]}', {col}" for col in value_vars])
stack_expr += ") as (quarter, revenue)"
melted_df = df.selectExpr(*id_vars, stack_expr)
melted_df.show()
This approach builds the stack expression programmatically, making it reusable for DataFrames with varying numbers of columns.
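The string-building step can be wrapped in a small reusable helper. The sketch below is illustrative (build_stack_expr and its parameters are hypothetical names, not a PySpark API); it produces the same expression string as the snippet above:

```python
def build_stack_expr(value_vars, var_name="variable", value_name="value",
                     label=None):
    """Build a stack() expression string for DataFrame.selectExpr().

    label: optional function mapping a column name to the literal stored
    in the variable column (e.g. stripping a '_revenue' suffix).
    """
    label = label or (lambda c: c)
    pairs = ", ".join(f"'{label(c)}', {c}" for c in value_vars)
    return f"stack({len(value_vars)}, {pairs}) as ({var_name}, {value_name})"

# Usage with the quarterly revenue DataFrame from above:
# melted_df = df.selectExpr(
#     "store_id", "store_name",
#     build_stack_expr(value_vars, "quarter", "revenue",
#                      label=lambda c: c.split("_")[0]))
```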
Method 2: Using SQL Expressions
For scenarios requiring more complex logic or when working in SQL-heavy environments, you can use SQL expressions directly:
df.createOrReplaceTempView("revenue_data")
melted_sql = spark.sql("""
SELECT
store_id,
store_name,
quarter,
revenue
FROM revenue_data
LATERAL VIEW stack(
4,
'Q1', Q1_revenue,
'Q2', Q2_revenue,
'Q3', Q3_revenue,
'Q4', Q4_revenue
) unpivoted AS quarter, revenue
""")
melted_sql.show()
The LATERAL VIEW clause with stack() achieves the same result but allows you to embed the melt operation within larger SQL queries, which can be beneficial when combining melting with joins, aggregations, or window functions.
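As a sketch of that benefit, the melt can feed an aggregation in the same statement. The query below (illustrative, run against the revenue_data temp view registered above) computes average revenue per quarter in one pass:

```python
# Melt and aggregate in a single SQL statement: average revenue per quarter.
quarterly_avg_sql = """
SELECT quarter, AVG(revenue) AS avg_revenue
FROM revenue_data
LATERAL VIEW stack(
    4,
    'Q1', Q1_revenue, 'Q2', Q2_revenue,
    'Q3', Q3_revenue, 'Q4', Q4_revenue
) unpivoted AS quarter, revenue
GROUP BY quarter
ORDER BY quarter
"""
# spark.sql(quarterly_avg_sql).show()
```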
Handling Multiple Value Columns
Real-world scenarios often require melting multiple measure columns simultaneously. For example, tracking both revenue and units sold:
multi_measure_data = [
(1, "Store_A", 1000, 50, 1500, 75, 1200, 60),
(2, "Store_B", 800, 40, 900, 45, 1100, 55)
]
multi_df = spark.createDataFrame(
multi_measure_data,
["store_id", "store_name", "Q1_revenue", "Q1_units", "Q2_revenue", "Q2_units", "Q3_revenue", "Q3_units"]
)
# Melt both revenue and units
melted_multi = multi_df.selectExpr(
"store_id",
"store_name",
"stack(3, 'Q1', Q1_revenue, Q1_units, 'Q2', Q2_revenue, Q2_units, 'Q3', Q3_revenue, Q3_units) as (quarter, revenue, units)"
)
melted_multi.show()
Output:
+--------+----------+-------+-------+-----+
|store_id|store_name|quarter|revenue|units|
+--------+----------+-------+-------+-----+
| 1| Store_A| Q1| 1000| 50|
| 1| Store_A| Q2| 1500| 75|
| 1| Store_A| Q3| 1200| 60|
| 2| Store_B| Q1| 800| 40|
| 2| Store_B| Q2| 900| 45|
| 2| Store_B| Q3| 1100| 55|
+--------+----------+-------+-------+-----+
Notice that stack() now takes three values per group instead of two, creating three output columns: the identifier (quarter) and two measure columns (revenue and units).
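Writing the multi-measure expression by hand gets error-prone as the number of groups and measures grows. A sketch of a generator (hypothetical helper, not part of PySpark) that builds it from a mapping of group label to measure columns:

```python
def build_multi_stack_expr(groups, var_name, value_names):
    """Build a stack() expression with several measure columns per group.

    groups: mapping of group label -> list of measure columns, e.g.
            {"Q1": ["Q1_revenue", "Q1_units"], ...}
    """
    parts = []
    for label, cols in groups.items():
        # Each group contributes its label literal plus all its measure columns
        parts.append(f"'{label}', " + ", ".join(cols))
    cols_out = ", ".join([var_name] + value_names)
    return f"stack({len(groups)}, {', '.join(parts)}) as ({cols_out})"

# Usage with the multi-measure DataFrame from above:
# melted_multi = multi_df.selectExpr(
#     "store_id", "store_name",
#     build_multi_stack_expr(
#         {"Q1": ["Q1_revenue", "Q1_units"],
#          "Q2": ["Q2_revenue", "Q2_units"],
#          "Q3": ["Q3_revenue", "Q3_units"]},
#         "quarter", ["revenue", "units"]))
```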
Performance Considerations and Best Practices
Melting operations fundamentally change your DataFrame’s shape, typically multiplying row count while reducing column count. This has significant implications for distributed processing.
Partition Strategy: After melting, consider repartitioning based on your downstream operations:
# Before melting - fewer rows, more columns
print(f"Wide format partitions: {df.rdd.getNumPartitions()}")
print(f"Wide format row count: {df.count()}")
# After melting - more rows, fewer columns
print(f"Long format partitions: {melted_df.rdd.getNumPartitions()}")
print(f"Long format row count: {melted_df.count()}")
# Repartition based on expected query patterns
melted_optimized = melted_df.repartition("quarter")
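Because the melted row count is predictable (wide rows times the number of value columns), you can size partitions up front instead of measuring afterward. A rough sketch (the helper name and the one-million-rows-per-partition target are illustrative assumptions, not a Spark rule):

```python
def estimate_melt_partitions(wide_rows, n_value_vars,
                             rows_per_partition=1_000_000):
    """Rough partition estimate: melting multiplies row count by the
    number of value columns being unpivoted."""
    long_rows = wide_rows * n_value_vars
    return max(1, -(-long_rows // rows_per_partition))  # ceiling division

# e.g. 10M wide rows melted over 4 quarters -> 40M long rows:
# melted_df.repartition(estimate_melt_partitions(10_000_000, 4))
```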
Execution Plan Analysis: Use explain() to understand how Spark executes your melt operation:
# Compare execution plans
print("=== Wide Format Query Plan ===")
df.filter("Q1_revenue > 1000").explain()
print("\n=== Melted Format Query Plan ===")
melted_df.filter("revenue > 1000").explain()
Caching Considerations: If you’ll reuse the melted DataFrame multiple times, cache it to avoid recomputing:
melted_df.cache()
melted_df.count() # Trigger caching
# Subsequent operations use cached data
result1 = melted_df.filter("quarter = 'Q1'").count()
melted_df.groupBy("quarter").avg("revenue").show()  # show() prints and returns None, so don't assign it
Schema Validation: Always verify the output schema matches expectations, especially when building dynamic melt expressions:
melted_df.printSchema()
# Validate expected columns exist
expected_cols = {"store_id", "store_name", "quarter", "revenue"}
actual_cols = set(melted_df.columns)
assert expected_cols == actual_cols, (
    f"Schema mismatch: missing={expected_cols - actual_cols}, "
    f"unexpected={actual_cols - expected_cols}"
)
For production pipelines processing millions of rows, benchmark different approaches. While stack() is generally efficient, extremely wide DataFrames (hundreds of columns) might benefit from alternative strategies like collecting column names and using programmatic approaches to build optimized expressions.
The key is understanding that melting in PySpark isn’t just a simple function call—it’s a distributed transformation that requires thoughtful consideration of data volume, cluster resources, and downstream processing requirements.